Are predictions on scikit-learn models thread-safe?


Given some classifier (SVC/Forest/NN/whatever), is it safe to call .predict on the same instance concurrently from different threads?

At first glance, my guess is that they do not mutate any internal state, but I could not find anything about it in the docs.

Here is a minimal example showing what I mean:

#!/usr/bin/env python3
import threading

from sklearn import datasets
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier

X, y = datasets.load_iris(return_X_y=True)

# Some model. Might be any type, e.g. (uncomment the one you want):
# clf = svm.SVC()
# clf = RandomForestClassifier()
clf = MLPClassifier(solver='lbfgs')

clf.fit(X, y)


def use_model_for_predictions():
    for _ in range(10000):
        clf.predict(X[0:1])


# Is this safe?
thread_1 = threading.Thread(target=use_model_for_predictions)
thread_2 = threading.Thread(target=use_model_for_predictions)
thread_1.start()
thread_2.start()

# Wait for both threads to finish.
thread_1.join()
thread_2.join()

Solution

  • Check out this Q&A: the predict and predict_proba methods should be thread-safe, because they only call NumPy and never modify the model itself. So the answer to your question is yes.

    You can also find some information in the replies here.
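
    The usage pattern described above, where one fitted estimator is shared by several threads, can be sketched with the standard library's concurrent.futures.ThreadPoolExecutor. This is a minimal sketch with illustrative assumptions: RandomForestClassifier stands in for "any classifier", and n_estimators=20, random_state=0, the 8 chunks and the 4 workers are arbitrary choices. It checks that the threaded predictions match a single-threaded run.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

X, y = datasets.load_iris(return_X_y=True)

# One fitted estimator, shared by every worker thread (illustrative settings).
clf = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)

# Reference result computed in a single thread.
expected = clf.predict(X)

# Split the rows into chunks; each chunk is predicted in a worker thread,
# with every thread calling .predict on the same instance.
chunks = np.array_split(X, 8)
with ThreadPoolExecutor(max_workers=4) as pool:
    # pool.map returns results in input order, so they can be concatenated.
    parts = list(pool.map(clf.predict, chunks))

concurrent = np.concatenate(parts)
print(np.array_equal(expected, concurrent))  # True when the threaded run matches the serial one
```

    Matching outputs do not prove thread safety on their own. They are consistent with predict being a read-only function of the fitted attributes, which is the argument made here.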

    For example, in naive Bayes the code is as follows:

    def predict(self, X):
        """
        Perform classification on an array of test vectors X.
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        Returns
        -------
        C : ndarray of shape (n_samples,)
            Predicted target values for X
        """
        check_is_fitted(self)
        X = self._check_X(X)
        jll = self._joint_log_likelihood(X)
        return self.classes_[np.argmax(jll, axis=1)]
    

    You can see that the first two lines only perform validation: check_is_fitted verifies that the model has been fitted, and _check_X validates the input. The abstract method _joint_log_likelihood is the one that interests us. It is described as:

    @abstractmethod
    def _joint_log_likelihood(self, X):
        """Compute the unnormalized posterior log probability of X
        I.e. ``log P(c) + log P(x|c)`` for all rows x of X, as an array-like of
        shape (n_classes, n_samples).
        Input is passed to _joint_log_likelihood as-is by predict,
        predict_proba and predict_log_proba.
        """
    

    And finally, for multinomial NB, for example, the function looks like this (source):

    def _joint_log_likelihood(self, X):
        """
        Compute the unnormalized posterior log probability of X, which is
        the features' joint log probability (feature log probability times
        the number of times that word appeared in that document) times the
        class prior (since we're working in log space, it becomes an addition)
        """
        joint_prob = X * self.feature_log_prob_.T + self.class_log_prior_
        return joint_prob
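
    To show that this computation depends only on fitted attributes, the sketch below recomputes it by hand from the public attributes feature_log_prob_, class_log_prior_ and classes_, and compares the result with MultinomialNB.predict. One note on notation: the quoted line writes X * self.feature_log_prob_.T, which is a matrix product when X is a SciPy sparse matrix. For a dense NumPy array the equivalent operator is @, and the sketch uses @ because the iris features are a dense array. Iris is used for illustration only. Its features are non-negative, which MultinomialNB requires.

```python
import numpy as np
from sklearn import datasets
from sklearn.naive_bayes import MultinomialNB

X, y = datasets.load_iris(return_X_y=True)  # dense, non-negative features
clf = MultinomialNB().fit(X, y)

# Joint log-likelihood: (n_samples, n_features) @ (n_features, n_classes),
# plus the class log prior broadcast across rows.
jll = X @ clf.feature_log_prob_.T + clf.class_log_prior_
manual = clf.classes_[np.argmax(jll, axis=1)]

# Nothing here touches the estimator except to read its fitted attributes.
print(np.array_equal(manual, clf.predict(X)))
```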
    

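    You can see that there is nothing thread-unsafe in predict: it only reads fitted attributes such as classes_, feature_log_prob_ and class_log_prior_, and writes nothing back to the estimator. Of course, you can go through the code and check the same for any of the other classifiers :)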
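
    Reading the source is the real check, but a quick empirical smoke test can complement it. The sketch below serializes the fitted estimator with pickle before and after several threads repeatedly call .predict on it, then compares the bytes. It rests on stated assumptions: MultinomialNB stands in for "any classifier", byte-identical pickles serve as a proxy for "no attribute was changed", and the thread count and call count are arbitrary. A passing run only shows that nothing was mutated during that run. It does not prove thread safety in general.

```python
import pickle
import threading

from sklearn import datasets
from sklearn.naive_bayes import MultinomialNB

X, y = datasets.load_iris(return_X_y=True)
clf = MultinomialNB().fit(X, y)

# Snapshot of the fitted state (assumption: pickling is deterministic here).
before = pickle.dumps(clf)


def hammer(n_calls=500):
    # Each thread calls predict repeatedly on the shared instance.
    for _ in range(n_calls):
        clf.predict(X[0:1])


threads = [threading.Thread(target=hammer) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# If predict had written anything back to the estimator, the bytes would differ.
after = pickle.dumps(clf)
print(before == after)
```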