pytorchmultilabel-classificationimbalanced-data

Multilabel classification with class imbalance in Pytorch


I have a multilabel classification problem, which I am trying to solve with CNNs in Pytorch. I have 80,000 training examples and 7900 classes; every example can belong to multiple classes at the same time, mean number of classes per example is 130.

The problem is that my dataset is very imbalance. For some classes, I have only ~900 examples, which is around 1%. For “overrepresented” classes I have ~12000 examples (15%). When I train the model I use BCEWithLogitsLoss from pytorch with a positive weights parameter. I calculate the weights the same way as described in the documentation: the number of negative examples divided by the number of positives.

As a result, my model overestimates almost every class… Mor minor and major classes I get almost twice as many predictions as true labels. And my AUPRC is just 0.18. Even though it’s much better than no weighting at all, since in this case the model predicts everything as zero.

So my question is, how do I improve the performance? Is there anything else I can do? I tried different batch sampling techniques (to oversample minority class), but they don’t seem to work.


Solution

  • I would suggest either one of these strategies

    Focal Loss

    A very interesting approach for dealing with un-balanced training data through tweaking of the loss function was introduced in
    Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He and Piotr Dollar Focal Loss for Dense Object Detection (ICCV 2017).
    They propose to modify the binary cross entropy loss in a way that decrease the loss and gradient of easily classified examples while "focusing the effort" on examples where the model makes gross errors.

    Hard Negative Mining

    Another popular approach is to do "hard negative mining"; that is, propagate gradients only for part of the training examples - the "hard" ones.
    see, e.g.:
    Abhinav Shrivastava, Abhinav Gupta and Ross Girshick Training Region-based Object Detectors with Online Hard Example Mining (CVPR 2016)