
Controlling the threshold in Logistic Regression in Scikit Learn


I am using the LogisticRegression() estimator in scikit-learn on a highly unbalanced data set. I have even set the class_weight parameter to "auto".

I know that in logistic regression it should be possible to find the threshold value for a particular pair of classes.

Is it possible to know what the threshold value is in each of the One-vs-All classes the LogisticRegression() method designs?

I did not find anything in the documentation page.

Does it by default apply the 0.5 value as threshold for all the classes regardless of the parameter values?


Solution

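  • Yes, scikit-learn uses a threshold of P > 0.5 for binary classification: predict() returns the class with the higher predicted probability. With more than two classes there is no single threshold, and predict() returns the class with the highest score. I am going to build on some of the answers already posted with two options to check this: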

    One simple option is to extract the class probabilities with log.predict_proba(test_x) and the class predictions with log.predict(test_x), as in the code below. Then append the predictions and their probabilities to your test dataframe as a check.
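A minimal sketch of that first check follows. It assumes a binary problem and uses synthetic data from make_classification purely for illustration; the variable names (log, test_x, test_y) mirror the ones used in this answer, and manual_pred is a hypothetical helper name. Note that recent scikit-learn versions use class_weight="balanced" where older ones accepted "auto".

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced binary data (illustrative only, not the asker's data).
X, y = make_classification(
    n_samples=2000, n_features=10, weights=[0.9, 0.1], random_state=0
)
train_x, test_x, train_y, test_y = train_test_split(
    X, y, stratify=y, random_state=0
)

log = LogisticRegression(class_weight="balanced", max_iter=1000)
log.fit(train_x, train_y)

pred_y = log.predict(test_x)         # hard class labels
probs_y = log.predict_proba(test_x)  # column 0: P(class 0), column 1: P(class 1)

# Rebuild the labels by thresholding P(class 1) at 0.5 and compare with predict().
manual_pred = log.classes_[(probs_y[:, 1] > 0.5).astype(int)]
agreement = np.mean(manual_pred == pred_y)
print(f"Fraction of rows where predict() matches the 0.5 rule: {agreement:.3f}")
```

If the two agree on every row, that is consistent with a 0.5 cutoff in the binary case. The same pred_y and probs_y arrays could be added as columns to a test dataframe for row-by-row inspection.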

    As another option, one can graphically view precision vs. recall at various thresholds using the following code.

    # Predict test_y labels and class probabilities with the fitted
    # logistic regression model (`log` is the fitted LogisticRegression)
    import matplotlib.pyplot as plt
    from sklearn.metrics import auc, precision_recall_curve

    pred_y = log.predict(test_x)

    # probs_y is a 2-D array: column 0 is P(class 0), column 1 is P(class 1)
    probs_y = log.predict_proba(test_x)

    # Use the probability of class 1 (second column of probs_y)
    precision, recall, thresholds = precision_recall_curve(test_y, probs_y[:, 1])
    pr_auc = auc(recall, precision)  # area under the precision-recall curve

    # precision and recall have one more entry than thresholds, so drop the last
    plt.title("Precision-Recall vs Threshold Chart")
    plt.plot(thresholds, precision[:-1], "b--", label="Precision")
    plt.plot(thresholds, recall[:-1], "r--", label="Recall")
    plt.ylabel("Precision, Recall")
    plt.xlabel("Threshold")
    plt.legend(loc="lower left")
    plt.ylim([0, 1])
    plt.show()
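Once a threshold has been read off such a chart, predict() itself cannot be told to use it, but the cutoff can be applied to the predict_proba output directly. The sketch below is one way to do that, again on synthetic binary data. Choosing the threshold that maximizes F1 is only an assumed criterion for illustration, and best_threshold and custom_pred are hypothetical names.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, precision_recall_curve
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced binary data (illustrative only).
X, y = make_classification(
    n_samples=2000, n_features=10, weights=[0.9, 0.1], random_state=0
)
train_x, test_x, train_y, test_y = train_test_split(
    X, y, stratify=y, random_state=0
)

log = LogisticRegression(max_iter=1000).fit(train_x, train_y)
probs_1 = log.predict_proba(test_x)[:, 1]  # P(class 1) for each test row

precision, recall, thresholds = precision_recall_curve(test_y, probs_1)

# F1 at each candidate threshold; the final precision/recall pair has no
# matching threshold, so it is dropped. The clip avoids division by zero.
p, r = precision[:-1], recall[:-1]
f1 = 2 * p * r / np.clip(p + r, 1e-12, None)
best_threshold = thresholds[np.argmax(f1)]

# Apply the custom cutoff instead of the default 0.5.
custom_pred = (probs_1 >= best_threshold).astype(int)
default_pred = log.predict(test_x)

print(f"Chosen threshold: {best_threshold:.3f}")
print(f"F1 at default cutoff: {f1_score(test_y, default_pred):.3f}")
print(f"F1 at chosen cutoff:  {f1_score(test_y, custom_pred):.3f}")
```

In practice the threshold should be picked on a validation set rather than on the test set, otherwise the reported score will be optimistic.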