Requirements

July 27, 2024

Keep in mind: in our research, KANs may not be well suited to text classification (or even to NLP in general). Further experiments are needed to confirm this.

This repo uses Kolmogorov-Arnold Networks (KANs) for text classification on GLUE tasks (RTE, CoLA, MRPC, etc.). Our paper will be published on arXiv soon.


Training

We use bert-base-cased as the pre-trained model that produces the embeddings (pooled_outputs) during training. All models have an input size of 768 (except the original KAN described below), 64 hidden neurons, and 2 output classes (0 & 1). Training was performed on a Tesla V100 (16 GB) for 10 epochs, with lr = 2e-5 for all transformer models and lr = 2e-3 for the other models.
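As a rough illustration of the head dimensions described above, the sketch below runs a forward pass of a one-hidden-layer MLP (768 → 64 → 2) on a batch of random vectors standing in for BERT pooled outputs. It is a hypothetical NumPy sketch, not the repo's TransformerMLP: the ReLU activation, the weight initialization, and the name mlp_head are all assumptions.

```python
import numpy as np

# Hypothetical sketch of a 768 -> 64 -> 2 classification head (not the repo's code).
M_SIZE, N_HIDDEN, N_CLASS = 768, 64, 2

rng = np.random.default_rng(0)
# Assumed small random initialization; a trained model would learn these weights.
W1 = rng.normal(0.0, 0.02, size=(M_SIZE, N_HIDDEN))
b1 = np.zeros(N_HIDDEN)
W2 = rng.normal(0.0, 0.02, size=(N_HIDDEN, N_CLASS))
b2 = np.zeros(N_CLASS)

def mlp_head(pooled):
    """pooled: (batch_size, 768) array standing in for BERT pooled outputs."""
    hidden = np.maximum(pooled @ W1 + b1, 0.0)  # (batch_size, 64), ReLU is assumed
    logits = hidden @ W2 + b2                   # (batch_size, 2)
    # Softmax over the two classes (0 and 1), shifted for numerical stability
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

batch = rng.normal(size=(4, M_SIZE))  # batch_size 4, as in the example commands
probs = mlp_head(batch)
predictions = probs.argmax(axis=1)    # predicted label (0 or 1) per example
```

In the real pipeline the random batch would be replaced by the pooled outputs of bert-base-cased, and the weights would be trained with the learning rates listed above.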

TransformerEfficientKAN

python run_train.py --mode "train" --network "trans_effi_kan" --em_model_name "bert-base-cased" --ds_name "mrpc" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 768 --n_hidden 64 --n_class 2

TransformerFastKAN

python run_train.py --mode "train" --network "trans_fast_kan" --em_model_name "bert-base-cased" --ds_name "mrpc" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 768 --n_hidden 64 --n_class 2

TransformerFasterKAN

python run_train.py --mode "train" --network "trans_faster_kan" --em_model_name "bert-base-cased" --ds_name "mrpc" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 768 --n_hidden 64 --n_class 2

TransformerMLP

python run_train.py --mode "train" --network "mlp" --em_model_name "bert-base-cased" --ds_name "mrpc" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 768 --n_hidden 64 --n_class 2

TransformerClassifier (with Dropout and Linear)

python run_train.py --mode "train" --network "classifier" --em_model_name "bert-base-cased" --ds_name "mrpc" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 768 --n_hidden 64 --n_class 2
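The five commands above differ only in the --network value. A small hypothetical helper (build_command is not part of the repo) can generate them, for example to loop over all networks for one dataset. It only builds argument lists and prints them; nothing runs unless you pass a list to subprocess.run yourself.

```python
import shlex

# Network names taken from the training commands in this README.
NETWORKS = ["trans_effi_kan", "trans_fast_kan", "trans_faster_kan", "mlp", "classifier"]

def build_command(network, ds_name="mrpc"):
    """Hypothetical helper: return the run_train.py argument list for one network."""
    return [
        "python", "run_train.py",
        "--mode", "train",
        "--network", network,
        "--em_model_name", "bert-base-cased",
        "--ds_name", ds_name,
        "--epochs", "10",
        "--batch_size", "4",
        "--max_len", "512",
        "--n_size", "1",
        "--m_size", "768",
        "--n_hidden", "64",
        "--n_class", "2",
    ]

for net in NETWORKS:
    # Print a shell-ready command; pass build_command(net) to subprocess.run to execute it.
    print(shlex.join(build_command(net)))
```

Keeping the arguments as a list (rather than one string) avoids shell-quoting issues when the commands are executed from Python.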

Original KAN

Training takes a very long time when the model computes outputs from a 768-dimensional input (outputs = KAN(texts)). Therefore, we reduce the embedding size from 768 to 8 (n_size*m_size) using reduce_size() in utils.py. The smaller the input size, the shorter the training time.

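import torch

def reduce_size(embeddings, n_size = 1, m_size = 8):
    # embeddings: (batch_size, hidden_size), e.g. (4, 768) for BERT pooled outputs
    first_dim, second_dim = embeddings.shape[0], embeddings.shape[-1]
    chunk = n_size * m_size  # target input size, e.g. 1 x 8 = 8
    assert second_dim % chunk == 0, "hidden size must be divisible by n_size*m_size"
    # Split each vector into second_dim/chunk groups of `chunk` values: (4, 96, 8)
    embeddings = torch.reshape(embeddings, (first_dim, second_dim // chunk, chunk))
    # Sum over the groups -> (batch_size, chunk); the batch dim is kept even if it is 1
    return torch.sum(embeddings, dim=1)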
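To make the reduction concrete: with n_size = 1 and m_size = 8, each 768-dimensional vector is split into 768 / 8 = 96 consecutive chunks of 8 values, and the chunks are summed element-wise, so output position j is the sum of inputs j, j + 8, j + 16, and so on up to j + 760. The NumPy sketch below (reduce_size_np is a hypothetical name, not a function in utils.py) mirrors this computation and checks it against an explicit loop.

```python
import numpy as np

def reduce_size_np(embeddings, n_size=1, m_size=8):
    """Hypothetical NumPy mirror of reduce_size(): (batch, d) -> (batch, n_size*m_size)."""
    batch_size, hidden_size = embeddings.shape
    chunk = n_size * m_size
    assert hidden_size % chunk == 0, "hidden size must be divisible by n_size*m_size"
    # (batch, d) -> (batch, d // chunk, chunk), then sum over the group axis
    return embeddings.reshape(batch_size, hidden_size // chunk, chunk).sum(axis=1)

rng = np.random.default_rng(0)
emb = rng.normal(size=(4, 768))  # stand-in for 4 BERT pooled outputs
reduced = reduce_size_np(emb)    # shape (4, 8)

# Explicit loop: output position j sums inputs j, j + 8, ..., j + 760
expected = np.zeros((4, 8))
for k in range(768 // 8):
    expected += emb[:, k * 8:(k + 1) * 8]
```

Note that summing (rather than, say, averaging) keeps the operation cheap but makes each of the 8 reduced features mix 96 unrelated BERT dimensions, which is one reason the reduced KAN results may not be directly comparable to the 768-input models.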

Then, we can reluctantly run the training:

python run_train.py --mode "train" --network "kan" --em_model_name "bert-base-cased" --ds_name "wnli" --epochs 10 --batch_size 4 --max_len 512 --n_size 1 --m_size 8 --n_hidden 64 --n_class 2 --device "cpu"

Parameters

  • mode: working mode ("train" or "test")
  • network: type of model ("trans_effi_kan", "trans_fast_kan", "trans_faster_kan", "mlp", "classifier", or "kan")
  • em_model_name: the pre-trained model that produces the embeddings (e.g. "bert-base-cased")
  • ds_name: dataset name (e.g. "mrpc", "wnli")
  • epochs: the number of epochs
  • batch_size: the training batch size
  • max_len: the maximum length of the input text
  • n_size, m_size: the input is treated as an n_size x m_size matrix. For example, BERT produces a 768-dimensional input (1 x 768).
  • n_hidden: the number of hidden neurons. We use only 1 hidden layer; you can modify the code to add more layers.
  • n_class: the number of classes. For the GLUE tasks used here, there are only 2 classes (0 & 1).
  • embed_type: the type of embeddings (pool, last hidden, or weight)
  • device: use "cuda" or "cpu"
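For reference, the parameters above could be declared with Python's argparse roughly as follows. This is a hypothetical reconstruction, not the actual parser in run_train.py: the defaults are copied from the example commands in this README, while the embed_type default ("pool") and the device default ("cuda") are assumptions, so check the script for its real defaults.

```python
import argparse

def build_parser():
    """Hypothetical parser mirroring the documented parameters (not run_train.py itself)."""
    p = argparse.ArgumentParser(description="Text classification with KANs on GLUE tasks")
    p.add_argument("--mode", choices=["train", "test"], default="train")
    p.add_argument("--network", default="trans_effi_kan")
    p.add_argument("--em_model_name", default="bert-base-cased")
    p.add_argument("--ds_name", default="mrpc")
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--max_len", type=int, default=512)
    p.add_argument("--n_size", type=int, default=1)
    p.add_argument("--m_size", type=int, default=768)
    p.add_argument("--n_hidden", type=int, default=64)
    p.add_argument("--n_class", type=int, default=2)
    p.add_argument("--embed_type", default="pool")  # assumed default
    p.add_argument("--device", choices=["cuda", "cpu"], default="cuda")  # assumed default
    return p

# Same settings as the original-KAN command on WNLI
args = build_parser().parse_args(
    ["--network", "kan", "--ds_name", "wnli", "--m_size", "8", "--device", "cpu"]
)
# The model input size is the n_size x m_size matrix flattened: 1 * 8 = 8
input_size = args.n_size * args.m_size
```

Declaring the numeric options with type=int lets the shape computation (n_size * m_size) work directly on the parsed values.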

Results

CoLA (10 epochs)

| Network | Training Accuracy | Val Accuracy | Training time (seconds) |
|---|---|---|---|
| trans_mlp | 0.9897 | 0.8282 | 2798 |
| trans_classifier | 0.9619 | 0.8282 | 2802 |
| trans_effi_kan | 0.9635 | 0.8292 | 2827 |
| trans_fast_kan | 0.9949 | 0.8206 | 2831 |
| trans_faster_kan | 0.9756 | 0.8215 | 2818 |
| effi_kan | 0.749 | 0.7458 | 951 |
| fast_kan | 0.7501 | 0.742 | 937 |
| faster_kan | 0.7235 | 0.7315 | 924 |

MRPC (10 epochs)

| Network | Training Accuracy | Val Accuracy | Training time (seconds) |
|---|---|---|---|
| trans_mlp | 0.7377 | 0.8603 | 1195 |
| trans_classifier | 0.9866 | 0.8848 | 1204 |
| trans_effi_kan | 0.9986 | 0.8676 | 1219 |
| trans_fast_kan | 0.9422 | 0.8554 | 1214 |
| trans_faster_kan | 0.9591 | 0.8701 | 1207 |
| effi_kan | 0.6955 | 0.7255 | 407 |
| fast_kan | 0.7009 | 0.7157 | 401 |
| faster_kan | 0.6848 | 0.7059 | 395 |

RTE (10 epochs)

| Network | Training Accuracy | Val Accuracy | Training time (seconds) |
|---|---|---|---|
| trans_mlp | 0.9302 | 0.675 | 821 |
| trans_classifier | 0.8475 | 0.625 | 818 |
| trans_effi_kan | 0.9069 | 0.675 | 826 |
| trans_fast_kan | 0.9394 | 0.6071 | 831 |
| trans_faster_kan | 0.9639 | 0.6964 | 829 |
| effi_kan | 0.5004 | 0.5214 | 277 |
| fast_kan | 0.5269 | 0.5429 | 273 |
| faster_kan | 0.496 | 0.5214 | 269 |


Contact

If you have any questions, please contact tahoangthang@gmail.com. If you want to know more about me, please visit my website: https://tahoangthang.com.