python, plotly, plotly.graph-objects

Putting a Linear Trendline on a Plotly Subplot


Is there an easier way to put a linear regression line on each Plotly subplot? The code below is not efficient, and it makes it hard to annotate the graph with the linear trendline equations, which I want shown on the graph. It is also hard to set axes and titles with this code.

I wondered whether I could create a go.Figure and place it on the subplot. I tried that, but Plotly only lets me add the figure's data (its traces) to the subplot, not the Figure itself, so I lose the title, axis, and trendline information. In addition, the trendline is hidden because the scatter plot is drawn on top of it. I tried reordering the traces with data=(data[1],data[0]), but that did not work.

Basically, is there a more efficient way to add a trendline to the scatter plots than the one I used, so that it is easier to set axes, set the graph size, create legends, and so on? What I coded is difficult to work with.

[Image: graph from my code]

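import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

sheets_dict=pd.ExcelFile('10.05.22_EMS172LabReport1.xlsx')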
sheets_list=np.array(sheets_dict.sheet_names[2:])

fig=make_subplots(rows=7,cols=1)

i=0
for name in sheets_list:
    df=sheets_dict.parse(name)
    df.columns=df.columns.str.replace(' ','')
    df=df.drop(df.index[0])
    slope,y_int=np.polyfit(df.CURR1,df.VOLT1,1)
    LR="Linear Fit: {:,.3e}x + {:,.3e}".format(slope,y_int)
    df['Best Fit']=slope*df.CURR1+y_int
    # RMSE: square each residual, average, then take the root
    rmse=np.sqrt(np.mean((df['Best Fit']-df.VOLT1)**2))
    i+=1
    fig.add_trace(
        go.Scatter(name='Best Fit Line'+" ± {:,.3e}V".format(rmse),x=df['CURR1'],y=df['Best Fit'],
                   mode='lines',line_color='red',line_width=2),row=i, col=1)
    fig.add_trace(
        go.Scatter(name='Voltage',x=df['CURR1'],y=df['VOLT1'],mode='markers'),
        row=i, col=1)
#     fig.data = (fig.data[1],fig.data[0])

fig.show()

Solution

  • Trendlines are built into plotly.express, with extensive functionality (see the plotly.express trendline documentation). It is possible to build a subplot from the data of such a figure, but I built the subplots with graph objects so that your existing code carries over.

    Since you did not provide your data, I used the stocks example dataset that ships with Plotly (px.data.stocks()). It is a data frame of relative stock price changes for several companies, and a trendline is added to each series.

    As for the graph, I set the height explicitly because stacked subplots need it. Axis titles are assigned per subplot by row and column; if you need axis titles on every subplot, add a call for each one. The legend is customized into two groups: one for the trendlines and one for the rate of change. As an example annotation, the fit error (rmse in the code) is placed near the left edge of each subplot (x = 0.1), at the height of that series' maximum value.

    import plotly.express as px
    import plotly.graph_objects as go
    import numpy as np
    
    df = px.data.stocks()
    
    df.head()
    #          date      GOOG      AAPL      AMZN        FB      NFLX      MSFT
    # 0  2018-01-01  1.000000  1.000000  1.000000  1.000000  1.000000  1.000000
    # 1  2018-01-08  1.018172  1.011943  1.061881  0.959968  1.053526  1.015988
    # 2  2018-01-15  1.032008  1.019771  1.053240  0.970243  1.049860  1.020524
    # 3  2018-01-22  1.066783  0.980057  1.140676  1.016858  1.307681  1.066561
    # 4  2018-01-29  1.008773  0.917143  1.163374  1.018357  1.273537  1.040708
    
    from plotly.subplots import make_subplots
    
    fig = make_subplots(rows=6,cols=1, subplot_titles=df.columns[1:].tolist())
    
    for i,c in enumerate(df.columns[1:]):
        dff = df[[c]].copy()
        # Least-squares line through (row index, value)
        slope,y_int=np.polyfit(dff.index, dff[c], 1)
        LR="Linear Fit: {:,.3e}x + {:,.3e}".format(slope,y_int)
        dff['Best Fit'] = slope*dff.index+y_int
        # RMSE: square each residual, average, then take the root
        rmse=np.sqrt(np.mean((dff['Best Fit']-dff[c])**2))
        fig.add_trace(go.Scatter(
            name='Best Fit Line'+" ± {:,.3e}".format(rmse),
            x=dff.index,
            y=dff['Best Fit'],
            mode='lines',
            line_color='blue',
            line_width=2,
            legendgroup='group1',
            legendgrouptitle_text='Trendline'), row=i+1, col=1)
        fig.add_trace(go.Scatter(
            x=dff.index,
            y=dff[c],
            legendgroup='group2',
            legendgrouptitle_text='Rate of change',
            mode='markers+lines', name=c), row=i+1, col=1)
        fig.add_annotation(x=0.1,
                           y=dff[c].max(),
                           xref='x',
                           yref='y',
                           text='{:,.3e}'.format(rmse),
                           showarrow=False,
                           yshift=5, row=i+1, col=1)
    
    fig.update_layout(autosize=True, height=800, title_text="Stock and Trendline")
    fig.update_xaxes(title_text="index", row=6, col=1)
    fig.update_yaxes(title_text="Rate of change", row=3, col=1)
    
    fig.show()
    
