Tags: python, scikit-learn, predict, multiclass-classification

sklearn: how to mock a zero-probability prediction for labels that were not present in the training set?


I hope I titled this correctly.

I am running a multiclass classification (3 classes) and scoring it with ROC AUC:

    make_scorer(roc_auc_score, needs_proba=True, average="macro", multi_class='ovo', labels=[-1, 0, 1])

I split the train/test data with a time series splitter and cannot reshuffle the order of the data (so there is no stratify parameter).

One of the splits does not contain the '0' label in the train data, so the .fit function only sees 2 labels and the predict_proba function returns only 2 columns.

When I run my multiclass ROC AUC scoring, I get this error: "ValueError: Number of given labels, 3, not equal to the number of columns in 'y_score', 2".
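A minimal sketch that reproduces the situation, using hypothetical toy data (the arrays below are not from the original problem): the training fold lacks label 0, so the fitted model exposes two classes and roc_auc_score rejects the two-column output.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

# Hypothetical toy data: the training fold has no 0 label, the test fold has all three.
X_train = np.array([[0.0], [0.2], [0.9], [1.0]])
y_train = np.array([-1, -1, 1, 1])
X_test = np.array([[0.1], [0.5], [0.8]])
y_test = np.array([-1, 0, 1])

clf = LogisticRegression().fit(X_train, y_train)
proba = clf.predict_proba(X_test)
print(clf.classes_)  # only the labels seen during fit
print(proba.shape)   # one column per seen label

# Scoring against three labels fails because y_score has only two columns.
try:
    roc_auc_score(y_test, proba, average="macro", multi_class="ovo", labels=[-1, 0, 1])
    error_message = None
except ValueError as exc:
    error_message = str(exc)
print(error_message)
```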

I would be fine with the model predicting a zero probability for the '0' class in that case, so I want to add a mock probability column. Is there any way to do this with standard sklearn functionality?
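One way to sketch the mock column with plain NumPy: allocate a zero array with one column per expected label, then copy each predicted column into the position of its label. The name pad_proba and the example values are hypothetical, chosen only for illustration.

```python
import numpy as np


def pad_proba(proba, seen_classes, all_labels):
    """Return one column per entry of all_labels; labels not in seen_classes get 0."""
    proba = np.asarray(proba, dtype=float)
    all_labels = list(all_labels)
    padded = np.zeros((proba.shape[0], len(all_labels)))
    # Column j of proba belongs to seen_classes[j]; move it to that label's slot.
    for src_col, label in enumerate(seen_classes):
        padded[:, all_labels.index(label)] = proba[:, src_col]
    return padded


# Hypothetical two-column output from a model fitted on labels -1 and 1 only.
two_col = [[0.2, 0.8], [0.6, 0.4]]
three_col = pad_proba(two_col, seen_classes=[-1, 1], all_labels=[-1, 0, 1])
print(three_col)
```

Because the added column is all zeros, each row still sums to 1, which roc_auc_score expects for multiclass scores.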

Any other recommendations? I thought of 1) wrapping predict_proba to add the missing probability column, or 2) changing the time series split so that, if the train data contains only 2 classes, it takes more train data.
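The second idea could look roughly like this. It is only a sketch under stated assumptions: expanding_splits is a hypothetical helper built on TimeSeriesSplit that moves the train/test boundary forward until every label has appeared in the training window, and drops a fold if no test rows would remain.

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit


def expanding_splits(y, labels, n_splits=3):
    """Yield (train_idx, test_idx) pairs whose training window contains every label."""
    y = np.asarray(y)
    needed = set(labels)
    for train_idx, test_idx in TimeSeriesSplit(n_splits=n_splits).split(y):
        end = train_idx[-1] + 1   # first index after the training window
        stop = test_idx[-1] + 1   # first index after the test window
        # Take more train data (in time order) until all labels have been seen.
        while end < stop and not needed <= set(y[:end].tolist()):
            end += 1
        if end < stop:            # keep the fold only if some test rows remain
            yield np.arange(end), np.arange(end, stop)


# Hypothetical label sequence in which 0 first appears late.
y_demo = np.array([-1, 1, -1, 1, -1, 1, 0, 1, -1, 0, 1, -1])
folds = list(expanding_splits(y_demo, labels=[-1, 0, 1], n_splits=3))
for train_idx, test_idx in folds:
    print(len(train_idx), test_idx)
```

The trade-off is that test folds shrink or disappear, so the number of folds is no longer guaranteed to equal n_splits.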


Solution

  • I'll post what I eventually did. I wrote a wrapper function that returns a child class of an estimator (e.g. LogisticRegression) with augmented fit and predict_proba methods. The fit method records which labels it has seen in y_train. The predict_proba method adds zero-filled columns for the labels that are listed in labels but were not present in y_train.

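    from functools import wraps
    from typing import List, Type

    import numpy as np
    import pandas as pd
    from sklearn.ensemble import ExtraTreesClassifier
    from sklearn.linear_model import LogisticRegression


    def predict_proba_wrapper(method: classmethod, labels: list):
        """Add zero columns to the predict_proba output for labels not present in y_train."""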
    
        @wraps(method)
        def wrapper(self, *args, **kwargs ):
            # find labels indices not in y_train
            indices_to_fill = []
            for i, label in enumerate(labels):
                if label not in self.labels_seen:
                    indices_to_fill.append(i)
            # call method
            y_pred = method(self, *args, **kwargs)
            # fill zeros
            if not isinstance(y_pred, np.ndarray):
                y_pred_np = np.array(y_pred)
            else:
                y_pred_np = y_pred
    
            for i in indices_to_fill:
                y_pred_np = np.insert(y_pred_np, i, 0., axis=1)
    
            if isinstance(y_pred, np.ndarray):
                return y_pred_np
            elif isinstance(y_pred, pd.DataFrame):
                return pd.DataFrame(y_pred_np, index=y_pred.index)
            elif isinstance(y_pred, pd.Series):
                return pd.Series(y_pred_np, index=y_pred.index)
            elif isinstance(y_pred, list):
                return y_pred_np.tolist()
            else:
                raise ValueError(f"y_pred type {type(y_pred)} not supported")
    
        return wrapper
    
    def fit_wrapper(method: classmethod, labels: list):
        """Add labels seen to the class."""
    
        @wraps(method)
        def wrapper(self, *args, **kwargs ):
            res = method(self, *args, **kwargs)
            if len(args) >= 2:
                y = args[1]
            else:
                y = kwargs["y"]
            if isinstance(y, np.ndarray):
                self.labels_seen = list(np.unique(y))
            elif isinstance(y, pd.DataFrame):
                self.labels_seen = list(y.iloc[:, 0].unique())
            elif isinstance(y, pd.Series):
                self.labels_seen = list(y.unique())
            elif isinstance(y, list):
                self.labels_seen = list(set(y))
            else:
                raise ValueError(f"y type {type(y)} not supported")
            # Caution: overwriting classes_ can affect methods that index into it (e.g. predict).
            if hasattr(self, "classes_"):
                if isinstance(self.classes_, np.ndarray):
                    self.classes_ = np.array(labels)
                elif isinstance(self.classes_, pd.DataFrame):
                    self.classes_ = pd.DataFrame(labels)
                elif isinstance(self.classes_, pd.Series):
                    self.classes_ = pd.Series(labels)
                elif isinstance(self.classes_, list):
                    self.classes_ = labels
                else:
                    raise ValueError(f"classes_ type {type(self.classes_)} not supported")
    
            return res
    
        return wrapper
    
    
    def class_child_with_wrapped_methods(class_: Type, method_names: List[str], wrappers: List[callable]):
        """Return a new class with a method wrapped by method wrapper."""
        new_class = type(class_.__name__ + "Wrapped", (class_,), {})
        for i, method_name in enumerate(method_names):
            setattr(new_class, method_name, wrappers[i](getattr(new_class, method_name)))
        return new_class
    
    
    def wrap_fit_predict_proba(class_: Type, labels: list):
        """Return a new class with predict_proba wrapped by predict_proba_wrapper."""
        return class_child_with_wrapped_methods(
            class_,
            ["predict_proba", "fit"],
            [
                lambda x: predict_proba_wrapper(x, labels),
                lambda x: fit_wrapper(x, labels)
            ]
        )
        
    CLASSIFIERS = [
        wrap_fit_predict_proba(LogisticRegression, labels=[-1, 0, 1]),
        wrap_fit_predict_proba(ExtraTreesClassifier, labels=[-1,0,1]),
    ]
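A lighter alternative to subclassing (not what the code above does) is to do the padding inside a scorer callable. This is a sketch under assumptions: padded_ovo_auc and ALL_LABELS are hypothetical names, the toy data is made up, and the estimator is assumed to expose classes_ in the same order as its predict_proba columns.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

ALL_LABELS = [-1, 0, 1]  # assumed sorted, as roc_auc_score requires for `labels`


def padded_ovo_auc(estimator, X, y):
    """Scorer callable: align predict_proba columns to ALL_LABELS, then score."""
    proba = estimator.predict_proba(X)
    padded = np.zeros((proba.shape[0], len(ALL_LABELS)))
    # Labels the estimator never saw keep a probability of zero.
    for src_col, label in enumerate(estimator.classes_):
        padded[:, ALL_LABELS.index(label)] = proba[:, src_col]
    return roc_auc_score(
        y, padded, average="macro", multi_class="ovo", labels=ALL_LABELS
    )


# Hypothetical toy data: label 0 is missing from the training fold.
X_train = np.array([[0.0], [0.2], [0.9], [1.0]])
y_train = np.array([-1, -1, 1, 1])
X_test = np.array([[0.1], [0.5], [0.8]])
y_test = np.array([-1, 0, 1])

model = LogisticRegression().fit(X_train, y_train)
score = padded_ovo_auc(model, X_test, y_test)
print(score)
```

A callable with the signature (estimator, X, y) can be passed as scoring= to cross_val_score or GridSearchCV, so the estimator itself stays unmodified.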