python pandas seaborn scatter-plot seaborn-objects

Plot regression confidence interval using seaborn.objects


How can I use the objects API to plot a regression line with confidence limits (like sns.regplot / sns.lmplot)?

Based on the so.Band fmri example, I thought it would be something like this:

import seaborn as sns
import seaborn.objects as so

df = sns.load_dataset("tips")

(so.Plot(df, x="total_bill", y="tip")
   .add(so.Dots())
   .add(so.Line(), so.PolyFit()) # regression line
   .add(so.Band(), so.Est())     # confidence band
)

But I'm getting an unexpected result (left). I'm expecting something like the regplot band (right).

current and expected output


Solution

  • seaborn.objects can only aggregate the data you provide (e.g. key/value pairs for histograms). Your use case is more complex: you need to create the bootstrapped regression data yourself. Other seaborn APIs can do this on the fly, but the objects interface cannot as of now (see the note below).

    Consider this code, which prepares bootstrapped predictions:

    import seaborn as sns
    import seaborn.objects as so
    import pandas as pd
    import numpy as np
    from sklearn.linear_model import LinearRegression
    from sklearn.ensemble import BaggingRegressor
    
    # data
    df = sns.load_dataset("tips")
    
    # bootstrap regression: tip ~ bill
    X = df['total_bill'].values.reshape(-1,1)
    y = df['tip'].values
    model = BaggingRegressor(LinearRegression(),
                             n_estimators=100,
                             max_samples=1.0, # 100% of the dataset
                             bootstrap=True)
    
    model.fit(X, y)
    bootstrapped_preds = pd.DataFrame([m.predict(X) for m in model.estimators_]).T
    
    # combine bootstrapped preds and original data
    # (melt over every bootstrap column, so all n_estimators fits are used)
    df_pred = pd.concat([df[['total_bill']], bootstrapped_preds], axis=1)
    df_pred = pd.melt(df_pred, id_vars='total_bill', value_vars=list(bootstrapped_preds.columns), value_name='tip')
    df_pred['type'] = 'pred'
    df_pred = pd.concat([df_pred,df[['total_bill','tip']]],axis=0)
    df_pred['type'] = df_pred['type'].fillna('observed')
    

    and this code, which plots a band of ±2 standard deviations around the predictions:

    # plot
    import seaborn as sns
    import seaborn.objects as so
    
    (so.Plot(df_pred, x="total_bill", y="tip", color="type")
        .add(so.Band(), so.Est(errorbar=('sd',2), n_boot=1000))
        .add(so.Dot(), so.Agg())
    )
    

    which generates this figure.

    And here is the full notebook.
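
As a possible alternative to the bagging approach above, the band limits can be computed directly and handed to so.Band as ymin/ymax. The following is only a minimal sketch under stated assumptions: `bootstrap_band` is a hypothetical helper (not part of seaborn), it fits a straight line with `np.polyfit`, and it uses synthetic stand-in data so that it runs offline. With the tips data you would pass `df["total_bill"]` and `df["tip"]` instead.

```python
import numpy as np
import pandas as pd


def bootstrap_band(x, y, n_boot=1000, ci=95, grid_size=100, seed=0):
    """Percentile bootstrap band for a straight-line fit (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    rng = np.random.default_rng(seed)
    grid = np.linspace(x.min(), x.max(), grid_size)

    # Refit the line on n_boot resamples (rows drawn with replacement)
    # and store each fit's predictions on the common x grid.
    preds = np.empty((n_boot, grid_size))
    for i in range(n_boot):
        idx = rng.integers(0, len(x), size=len(x))
        slope, intercept = np.polyfit(x[idx], y[idx], deg=1)
        preds[i] = slope * grid + intercept

    # Fit once on the full data for the central regression line.
    slope, intercept = np.polyfit(x, y, deg=1)
    lower = (100 - ci) / 2
    return pd.DataFrame({
        "x": grid,
        "y": slope * grid + intercept,
        "ymin": np.percentile(preds, lower, axis=0),
        "ymax": np.percentile(preds, 100 - lower, axis=0),
    })


# Synthetic stand-in data (an assumption, used only to keep the sketch offline)
rng = np.random.default_rng(42)
x_demo = rng.uniform(3, 50, size=200)
y_demo = 1.0 + 0.1 * x_demo + rng.normal(0, 1, size=200)

band = bootstrap_band(x_demo, y_demo)
```

The resulting frame should be plottable with something like `so.Plot(band, x="x", y="y", ymin="ymin", ymax="ymax").add(so.Line()).add(so.Band())`. A percentile interval over resampled fits is similar in spirit to the band that sns.regplot draws, although the two are not guaranteed to match exactly.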


    Seaborn's objects API is still immature as of now (although very interesting!) and does not yet support some requested functionality. The maintainers are also declining PRs that go too far for the time being, promising more features once the code stabilizes. Quoting the GitHub discussion (Feb '23):

    I believe PolyFit is basically a placeholder now, to provide basic functionality, but may change in the future, which is why the doc is incomplete. As for having more objects, I think the answer is in #3133 : it is definitely intended to have more objects, but probably once the API is more stable. Adding the objects now would mean more maintenance and discourage changing the API when it may be necessary. Having them in the Discussions seems a good compromise to me : accessible to other users, but without the pressure of maintenance.

    My advice is not to look for hacks, but simply to use the standard API and keep your fingers crossed for improvements in future releases of seaborn.objects.
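
For reference, the standard-API route might look like the sketch below. It uses sns.regplot, which the question already mentions, on a synthetic stand-in for the tips data (an assumption made only so the example runs without a download; `sns.load_dataset("tips")` can be swapped in). The `ci`, `n_boot` and `seed` arguments are spelled out for clarity and can be left at their defaults.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for the tips data (assumption, keeps the example offline)
rng = np.random.default_rng(0)
total_bill = rng.uniform(3, 50, size=200)
demo = pd.DataFrame({
    "total_bill": total_bill,
    "tip": 1.0 + 0.1 * total_bill + rng.normal(0, 1, size=200),
})

# regplot draws the scatter, the fitted line and a bootstrapped
# confidence band in a single call
ax = sns.regplot(data=demo, x="total_bill", y="tip", ci=95, n_boot=1000, seed=0)
```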