matplotlib, scikit-learn

Use and save plots created by scikit-learn's RocCurveDisplay with matplotlib functions


I have created a plot with scikit-learn's RocCurveDisplay class, but I cannot figure out how to save it: I cannot find a way to apply matplotlib functions such as savefig() to it. I would also like to place RocCurveDisplay plots inside matplotlib's subplots(). Is this possible? Could someone help and, if possible, explain a little?


Solution

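  • If no ax is passed, RocCurveDisplay.plot() creates a new figure with a single Axes (a 1x1 grid from plt.subplots()). After plotting, that figure and Axes are stored on the display as figure_ and ax_, so a single plot can be saved with roc_display.figure_.savefig("roc.png").

    If you need more than one plot, create the subplots yourself and pass each Axes to a display's plot() method through the ax parameter: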

    import matplotlib.pyplot as plt
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 5))
    
    # roc_display and pr_display are Display objects built as in the docs
    # example (Visualizations with Display Objects)
    
    roc_display.plot(ax=ax1)  # draw the RocCurveDisplay on the left Axes
    pr_display.plot(ax=ax2)   # any other Display object, e.g. a PrecisionRecallDisplay
    
    # save the whole figure (both subplots) to a file
    plt.savefig("display_object_skl.png", dpi=300, bbox_inches="tight")
    
    plt.show()
