tensorflow, machine-learning

Should the data in batch be balanced?


I'm training a deep learning model to predict three emotions (joy, sadness, anger) from the content of tweets.

The problem I've run into is that my model learns joy well but does very badly on sadness and anger.

[Confusion matrix over the three emotions]

I think the reason is that my training dataset is imbalanced.

Number of training examples per class: joy: 196952, sadness: 29407, anger: 42420

So during training, each batch contains mostly joy examples, which pushes the model to just guess joy rather than the other emotions.

I want to fix this issue by balancing the data in each batch. Say the batch size is 128; we would randomly choose the same amount of data from each of the three emotions, so the model is not dominated by the joy data.
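
Concretely, I have in mind something like this sketch with tf.data (the joy_texts / joy_labels arrays and the other per-class arrays are hypothetical placeholders for however the data is actually loaded; tf.data.Dataset.sample_from_datasets is the newer spelling, older TF versions have it under tf.data.experimental):

    import tensorflow as tf

    # Hypothetical per-class datasets of (tweet, label) pairs; .repeat() keeps
    # sampling going after the smallest class has been used up.
    joy_ds = tf.data.Dataset.from_tensor_slices((joy_texts, joy_labels)).repeat()
    sad_ds = tf.data.Dataset.from_tensor_slices((sad_texts, sad_labels)).repeat()
    anger_ds = tf.data.Dataset.from_tensor_slices((anger_texts, anger_labels)).repeat()

    # Draw from each class with equal probability, so a batch of 128 contains
    # roughly the same number of joy / sadness / anger examples.
    balanced_ds = tf.data.Dataset.sample_from_datasets(
        [joy_ds, sad_ds, anger_ds], weights=[1/3, 1/3, 1/3]
    ).batch(128)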

My question is: should the data in each batch be balanced?

My other question is: if I randomly sample the data like this, does that violate the definition of an epoch?

An epoch is supposed to mean one pass over the whole training dataset, but with random sampling some examples may never be chosen in a given epoch. Or would simply training for more epochs fix this?


Solution

  • A possible approach is to add class weights to the classifier.

    From: https://www.tensorflow.org/tutorials/structured_data/imbalanced_data#class_weights

    The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want to have the classifier heavily weight the few examples that are available. You can do this by passing Keras weights for each class through a parameter. These will cause the model to "pay more attention" to examples from an under-represented class.

    As your problem is multiclass, you can do it with https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_class_weight.html

    I do this with something like:

    from sklearn.utils import class_weight

    # Compute one weight per class, inversely proportional to how often the
    # class appears in the training split, and turn the result into the
    # {class_index: weight} dict that Keras expects (available_labels is the
    # array of class labels present in y).
    class_weights = dict(enumerate(class_weight.compute_class_weight(
        class_weight='balanced',
        classes=available_labels,
        y=self.dataset.get_split(df, 'train')['label']
    )))
    

    And then:

    history = model.fit(
        ...
        class_weight=class_weights
    )
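
    As a rough illustration (not from your actual pipeline), the 'balanced' heuristic is simply n_samples / (n_classes * n_samples_per_class), so with the class counts from the question the weights would come out roughly like this:

    import numpy as np
    from sklearn.utils import class_weight

    # Class counts copied from the question.
    counts = {'anger': 42420, 'joy': 196952, 'sadness': 29407}
    y = np.repeat(list(counts.keys()), list(counts.values()))

    weights = class_weight.compute_class_weight(
        class_weight='balanced',
        classes=np.array(sorted(counts)),
        y=y,
    )
    # weights ~= [2.11 (anger), 0.45 (joy), 3.05 (sadness)]:
    # the under-represented classes get proportionally larger weights.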
    

    In my experience, this approach achieves better results and at the same time makes training faster.

    In addition, I think that keeping batches large and making sure the data is shuffled are also good practices when working with imbalanced data.
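
    With a tf.data input pipeline the shuffling part can be as simple as the following sketch (train_ds here is a hypothetical dataset of (tweet, label) pairs, not something defined in the question):

    import tensorflow as tf

    # Shuffle with a reasonably large buffer before batching, so every (large)
    # batch is an approximately random mix of the three emotions.
    train_ds = train_ds.shuffle(buffer_size=50_000).batch(256).prefetch(tf.data.AUTOTUNE)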