python matplotlib stem

Set custom and changing baseline to stem plot in Matplotlib


I'm trying to make a figure in which the stem plot uses the values of dataframe_3_merged['TOTAL'] as its baseline.

import numpy as np
from eurostatapiclient import EurostatAPIClient
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import seaborn as sns
import pandas as pd
# Eurostat API settings; so far only the version and format used here are
# available. The client is created once and reused for both downloads.
VERSION = 'v2.1'
FORMAT = 'json'
LANGUAGE = 'en'
client = EurostatAPIClient(VERSION, FORMAT, LANGUAGE)


# Query strings for the two datasets (ilc_peps01 filtered on age=TOTAL,
# ilc_peps13 filtered on DEG1/DEG2/DEG3), both restricted to time=2018.
dataframe_3_query_total = 'ilc_peps01?precision=1&sex=T&geo=AT&geo=BE&geo=BG&geo=CH&geo=CY&geo=CZ&geo=DK&geo=EA19&geo=EE&geo=EL&geo=ES&geo=EU28&geo=FI&geo=FR&geo=HR&geo=HU&geo=IE&geo=IS&geo=IT&geo=LT&geo=LU&geo=LV&geo=ME&geo=MK&geo=MT&geo=NL&geo=NO&geo=PL&geo=PT&geo=RO&geo=RS&geo=SE&geo=SI&geo=SK&geo=TR&geo=UK&unit=PC&unitLabel=label&time=2018&age=TOTAL'
dataframe_3_query_urb = 'ilc_peps13?precision=1&deg_urb=DEG1&deg_urb=DEG2&deg_urb=DEG3&geo=AT&geo=BE&geo=BG&geo=CH&geo=CY&geo=CZ&geo=DE&geo=DK&geo=EA19&geo=EE&geo=EL&geo=ES&geo=EU28&geo=FI&geo=FR&geo=HR&geo=HU&geo=IE&geo=IS&geo=IT&geo=LT&geo=LU&geo=LV&geo=MK&geo=MT&geo=NL&geo=NO&geo=PL&geo=PT&geo=RO&geo=RS&geo=SE&geo=SI&geo=SK&geo=UK&unit=PC&unitLabel=label&time=2018'


# Download each dataset in long form, then pivot it so that every 'geo'
# value becomes one row and every category becomes one column.
total_long = client.get_dataset(dataframe_3_query_total).to_dataframe()
dataframe_3_total = total_long.pivot(index='geo', columns='age', values='values')

urb_long = client.get_dataset(dataframe_3_query_urb).to_dataframe()
dataframe_3_urb = urb_long.pivot(index='geo', columns='deg_urb', values='values')

# Combine both tables on their shared 'geo' index and discard every row
# that has a missing value in any column.
dataframe_3_merged = dataframe_3_total.join(dataframe_3_urb).dropna()


# Draw the DEG1/DEG2/DEG3 columns as three stem series on one axes and
# overlay the TOTAL column as horizontal tick markers.
fig, ax = plt.subplots(figsize=(15, 4))
plt.ylim(0, 51)

# One x position per row of the merged frame. Deriving the length from the
# data (instead of hard-coding 32) keeps x and the plotted columns the same
# size whenever the set of rows surviving dropna() changes.
x = range(len(dataframe_3_merged))
stem_1 = plt.stem(x, dataframe_3_merged['DEG1'])
stem_2 = plt.stem(x, dataframe_3_merged['DEG2'])
stem_3 = plt.stem(x, dataframe_3_merged['DEG3'])
# Recolour the second and third series so the three can be told apart;
# the first keeps the default colour.
plt.setp(stem_2, color='r')
plt.setp(stem_3, color='g')

# NOTE(review): the stems use integer positions while the scatter uses the
# index labels; presumably matplotlib places the labels at 0..N-1 so both
# line up -- confirm on the rendered figure.
scatterplot = sns.scatterplot(x=dataframe_3_merged.index,  # TOTAL drawn as '_' markers
                              y=dataframe_3_merged['TOTAL'],
                              ax=ax,
                              s=100,
                              legend=False,
                              marker="_",
                              color='b')

The goal is to have a plot similar to this image:
1
I tried to pass dataframe_3_merged['TOTAL'] as the bottom argument of plt.stem, but I get this error: ValueError: setting an array element with a sequence. Thank you for your help!


Solution

  • You could replace each stem plot by a scatter plot and a plot of vertical lines (plt.vlines). Setting the zorder=0 ensures the lines are drawn behind the dots.

    from matplotlib import pyplot as plt
    import numpy as np
    import pandas as pd
    
    names = ['hydrogen', 'helium', 'lithium', 'beryllium', 'boron', 'carbon', 'nitrogen', 'oxygen', 'fluorine', 'neon', 'sodium', 'magnesium', 'aluminium', 'silicon', 'phosphorus', 'sulphur', 'chlorine', 'argon', 'potassium', 'calcium', 'scandium', 'titanium', 'vanadium', 'chromium', 'manganese', 'iron', 'cobalt', 'nickel', 'copper', 'zinc', 'gallium', 'germanium', 'arsenic', 'selenium', 'bromine', 'krypton']

    N = len(names)
    # Demo data: one random walk per column, each starting from its own
    # offset, indexed by the names above.
    offsets = [('Deg1', 35), ('Deg2', 25), ('Deg3', 15)]
    df = pd.DataFrame({column: start + np.random.normal(size=N).cumsum()
                       for column, start in offsets},
                      index=names)
    # The per-row mean of the three columns is the value the lines start from.
    df['Total'] = df.mean(axis=1)

    # (column, dot colour, legend label) for each series.
    series_styles = [('Deg1', 'tomato', 'label1'),
                     ('Deg2', 'orange', 'label2'),
                     ('Deg3', 'palegreen', 'label3')]
    for column, dot_color, legend_text in series_styles:
        # Thin vertical line between the series value and the row's Total;
        # zorder=0 keeps the line behind the dots.
        plt.vlines(df.index, df[column], df['Total'], lw=0.2, color='k', zorder=0)
        plt.scatter(df.index, df[column], marker='o', color=dot_color, label=legend_text)
    # Mark the Total of every row with a horizontal dash.
    plt.scatter(df.index, df['Total'], marker='_', color='deepskyblue', s=100)

    # Axes cosmetics: vertical tick labels, fixed y range, small x margin,
    # legend placed below the axes, horizontal grid lines, no tick marks.
    plt.xticks(rotation='vertical')
    plt.ylim(0, 51)
    plt.margins(x=0.02)
    plt.legend(ncol=3, bbox_to_anchor=(0.5, -0.4), loc='upper center')
    plt.grid(True, axis='y')
    plt.tick_params(length=0)
    # Hide every spine except the bottom one.
    axes_spines = plt.gca().spines
    for side in ('top', 'left', 'right'):
        axes_spines[side].set_visible(False)
    plt.tight_layout()
    plt.show()
    

    example plot