Feb 10, 2021

Multi-class Classification using Bert with Keras and Tensorflow

Problem Statement

We will be implementing multi-class classification using Bert with Keras and Tensorflow.

In the current scenario, we will be classifying documents/sentences into one of 8 categories (0 to 7), i.e. the trained Bert model will assign each sentence in the text corpus one of the labels 0, 1, 2, 3, 4, 5, 6, or 7.

Important Details of the Solution

The complete code with output is available on my GitHub at:

https://github.com/srichallla/NLP/blob/main/Bert_kerasTF_multiclassification.ipynb

We will be using a pre-trained Bert model to extract embeddings for each sentence in the text corpus and then use these embeddings to train a text classification model. We then use this trained Bert model to classify text on an unseen test dataset.
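The pre-trained tokenizer can be loaded from the Hugging Face transformers library. A minimal sketch (assuming transformers and tensorflow are installed, and using the same bert-base-uncased checkpoint that appears later in this post):

from transformers import BertTokenizer

# Load the WordPiece tokenizer that matches the pre-trained Bert checkpoint
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')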

Bert expects labels/categories to start from 0 instead of 1, otherwise the classification task may not work as expected or can throw errors. If your dataset has labels starting from 1, they should be modified. In the current dataset, the labels start from 1, so we remap them to start from 0 as below:

df3['label_encode'] = df3['Label'].map({'1':0,'2':1,'3':2,'4':3,'5':4,'6':5,'7':6,'8':7})

Here we create a new column in the dataframe to store the remapped labels instead of overwriting the existing one.

Also, the "label" column should be of type int or float. If the "label" column is of type obj OR string, it has to be converted to int or float. Else Bert will not work as expected or can throw errors. In the current dataset "label_encode" column is of type obj. So converting it into int type as below:

df3['label_encode'] = df3.label_encode.astype(int)

An important limitation of Bert is that the maximum length of each sentence/sequence is 512 tokens. Here we set it to 200. If a sentence is shorter than 200 tokens, it is padded with zeros; if it is longer, it is truncated.

The larger the sentence length, the longer the training takes.
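To see the padding/truncation behavior in isolation, here is a small illustrative sketch (using the tokenizer loaded above and a deliberately tiny max_length; it is not part of the actual pipeline):

enc = tokenizer.encode_plus("Bert is great",
                            add_special_tokens=True,
                            max_length=10,
                            pad_to_max_length=True,  # pad short inputs with 0s ([PAD])
                            truncation=True)         # cut long inputs at max_length
print(len(enc['input_ids']))  # always 10: short sentences are padded, long ones truncated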

Bert inputs to TFBertForSequenceClassification model:

1) input_ids: token indices, i.e. numerical representations of each token in a sentence. The Bert tokenizer splits each sentence into (sub)word tokens and maps each token to a number from the "WordPiece vocabulary" dictionary by means of a lookup, where the key is the token and the value is its numerical index.

2) token_type_ids: all the same, since we do not have question-answer or sentence-pair inputs. For this classification problem we treat the whole sentence as a single segment, so in our case they are all zeros.

3) attention_mask: tells the model not to attend to [PAD] tokens.

4) labels: the actual labels (the "label_encode" column) from the labeled dataset. Required for the training and validation datasets; not required for the test dataset, as those are what the model predicts.

We will use the encode_plus function, which produces the first three inputs for us; the fourth one (the labels) we handle manually.

tokenizer.encode_plus(sentence,
                      add_special_tokens = True,    # add [CLS], [SEP]
                      max_length = 200,             # max length of the text that can go to BERT
                      pad_to_max_length = True,     # add [PAD] tokens
                      return_attention_mask = True, # add attention mask to not focus on pad tokens
                      )
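To feed all three encoder outputs plus the labels into the model, the sentences can be encoded in a loop and packed into a tf.data.Dataset. The sketch below is a rough outline rather than the exact notebook code; the helper name encode_sentences and the text column name 'text' are assumptions:

import numpy as np
import tensorflow as tf

def encode_sentences(sentences, labels, max_len=200):
    input_ids, token_type_ids, attention_masks = [], [], []
    for sent in sentences:
        enc = tokenizer.encode_plus(sent,
                                    add_special_tokens=True,
                                    max_length=max_len,
                                    pad_to_max_length=True,
                                    return_attention_mask=True,
                                    truncation=True)
        input_ids.append(enc['input_ids'])
        token_type_ids.append(enc['token_type_ids'])
        attention_masks.append(enc['attention_mask'])
    # Pair the three Bert inputs with the integer labels (the 4th input)
    return tf.data.Dataset.from_tensor_slices((
        {'input_ids': np.array(input_ids),
         'token_type_ids': np.array(token_type_ids),
         'attention_mask': np.array(attention_masks)},
        np.array(labels)))

# 'text' is an assumed column name for the raw sentences
train_ds = encode_sentences(df3['text'].tolist(),
                            df3['label_encode'].tolist()).shuffle(1000).batch(6)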

For multi-class classification, we must specify the number of unique categories/labels to classify the sentences/documents, 8 in our case.

model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased',num_labels=8)

If you do not specify num_labels, the predictions on the test dataset will all be NaN.

Finally, we train the Bert model on our dataset and then use the trained model to predict the class labels on the test dataset.
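A minimal sketch of that training and prediction step, assuming the train_ds built above and a test_ds of encoded test inputs (without labels); the optimizer and learning rate are illustrative choices, not the exact notebook settings:

# Integer labels and raw logits -> sparse categorical cross-entropy with from_logits=True
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_ds, epochs=1)

preds = model.predict(test_ds)
# Depending on the transformers version, predict returns an output object with
# a .logits attribute or the raw logits array directly.
logits = preds.logits if hasattr(preds, 'logits') else preds
predicted_labels = np.argmax(logits, axis=-1)  # one of 0..7 per test sentence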

With a batch size of 6, a max sentence length of 200, and 1 epoch, we got 97% accuracy.