I am training an online-learning SVM classifier using SGDClassifier in sklearn. I learnt that it is possible using partial_fit.
My model definition is:
model = SGDClassifier(loss="hinge", penalty="l2", alpha=0.0001, max_iter=3000, tol=1e-3, shuffle=True, verbose=0, learning_rate='invscaling', eta0=0.01, early_stopping=False)
and it is created only the first time.
To test it, I first trained my classifier (model 1) on the entire data using fit and got 87% accuracy (using model.score(X_test, y_test)). Then, to demonstrate online training, I broke the same data into 4 sets and fed the 4 parts in 4 separate runs using partial_fit. This was model 2.
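Roughly, the two setups looked like the sketch below (the synthetic data, the 4-way split via np.array_split, and the variable names are illustrative assumptions, not my actual code):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the real dataset, which is not shown here
X, y = make_classification(n_samples=4000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Model 1: batch training on the entire training data
model_1 = SGDClassifier(loss="hinge", penalty="l2", alpha=0.0001, max_iter=3000, tol=1e-3, shuffle=True, verbose=0, learning_rate='invscaling', eta0=0.01, early_stopping=False)
model_1.fit(X_train, y_train)
print(model_1.score(X_test, y_test))

# Model 2: the same configuration, fed the training data in 4 chunks
model_2 = SGDClassifier(loss="hinge", penalty="l2", alpha=0.0001, max_iter=3000, tol=1e-3, shuffle=True, verbose=0, learning_rate='invscaling', eta0=0.01, early_stopping=False)
classes = np.unique(y_train)  # partial_fit needs the full class list on the first call
for X_chunk, y_chunk in zip(np.array_split(X_train, 4), np.array_split(y_train, 4)):
    model_2.partial_fit(X_chunk, y_chunk, classes=classes)
    print(model_2.score(X_test, y_test))  # accuracy after each incremental run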
But in this case, my accuracy over the 4 runs fluctuated and then dropped: 87.9 -> 98.89 -> 47.7 -> 29.4.
What could be the cause for this?
This is how I got over it.
Usually, partial_fit is seen to be prone to reductions or fluctuations in accuracy. To some extent, this can be mitigated by shuffling and providing only small fractions of the entire dataset at a time. But for larger data, online training with SGDClassifier/an SVM classifier only seems to give decreasing accuracy.
I tried to experiment with it and discovered that using a low learning rate can sometimes help. The rough analogy is that training the same model repeatedly on new chunks of data leads to it forgetting what it learnt from the previous data, so a tiny learning rate slows down the forgetting as well as the learning. One simple way to do this is to pin the step size to a small constant value, as in the sketch below.
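For example (a sketch reusing the SGDClassifier import from the earlier snippet; the exact eta0 value is an arbitrary illustration, not a recommendation):

# Same configuration as before, but with a small fixed step size instead of the 'invscaling' schedule
model = SGDClassifier(loss="hinge", penalty="l2", alpha=0.0001, max_iter=3000, tol=1e-3, shuffle=True, verbose=0, learning_rate='constant', eta0=0.0001, early_stopping=False)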
Rather than manually providing a rate, we can use the adaptive learning rate functionality provided by sklearn:
model = SGDClassifier(loss="hinge", penalty="l2", alpha=0.0001, max_iter=3000, tol=None, shuffle=True, verbose=0, learning_rate='adaptive', eta0=0.01, early_stopping=False)
This is described in the scikit-learn docs as:
‘adaptive’: eta = eta0, as long as the training keeps decreasing. Each time n_iter_no_change consecutive epochs fail to decrease the training loss by tol or fail to increase validation score by tol if early_stopping is True, the current learning rate is divided by 5.
Initially, with each round of new data, the accuracy would drop as before: 87.9 -> 98.89 -> 47.7 -> 29.4
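Rerunning the same chunk-by-chunk loop with the adaptive model (again a sketch, assuming the chunked data and test split from the earlier snippet):

classes = np.unique(y_train)  # partial_fit still needs the full class list on the first call
for X_chunk, y_chunk in zip(np.array_split(X_train, 4), np.array_split(y_train, 4)):
    model.partial_fit(X_chunk, y_chunk, classes=classes)
    print(model.score(X_test, y_test))  # accuracy after each round of new data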
Now, we get better results, with 100% accuracy, although there is a high risk of over-fitting due to the increased number of epochs. I have tried to demonstrate these observations in this kaggle notebook.