
I'm trying to run the following code, but I keep getting an error on the last line.


This is the code I am running:

    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.model_selection import train_test_split, cross_val_score
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.naive_bayes import GaussianNB, MultinomialNB
    from sklearn.metrics import accuracy_score
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.neural_network import MLPClassifier
    from sklearn.metrics import precision_score, recall_score, auc
    from sklearn.metrics import roc_curve, roc_auc_score, plot_roc_curve

This was the error:

    ImportError                               Traceback (most recent call last)
    <ipython-input-3-d1b430d75826> in <cell line: 7>()
          5 from sklearn.tree import DecisionTreeClassifier
          6 from sklearn.naive_bayes import GaussianNB, MultinomialNB
    ----> 7 from sklearn.metrics import accuracy_score, precision_score,  recall_score, roc_curve, roc_auc_score, plot_roc_curve
          8 from sklearn.neighbors import KNeighborsClassifier
          9 from sklearn.neural_network import MLPClassifier

    ImportError: cannot import name 'plot_roc_curve' from 'sklearn.metrics' (/usr/local/lib/python3.10/dist-packages/sklearn/metrics/__init__.py)

Is my code wrong, or is it missing something?


Solution

  • The error occurs because plot_roc_curve was deprecated in scikit-learn 1.0 and removed in 1.2. Google Colab currently appears to ship scikit-learn 1.3.2, so the function no longer exists there. Use RocCurveDisplay from the sklearn.metrics module instead:

    # import sklearn
    # print(sklearn.__version__)  # 1.3.2
    from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_curve, roc_auc_score, RocCurveDisplay
    

To plot the ROC curve, use RocCurveDisplay like this:

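    import matplotlib.pyplot as plt
    # your_model must already be fitted; replace the placeholder names with your own objects
    RocCurveDisplay.from_estimator(your_model, your_X_test, your_y_test)
    plt.show()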