python matplotlib pivot xticks

For a pandas DataFrame with a MultiIndex, how can I group the x-axis labels by the value of one index level for visualization in Python?


I have a pandas dataframe which is basically a pivot table.

df.plot(kind="bar", stacked=True) produces the plot below. The x-axis labels are congested, as shown. [image]

In Excel, I can group the labels by the first index level, so that the scenarios pes, tes and des are clear and distinct, as shown: [image]

How can I create similar labels in x-axis using matplotlib in Python?

Here is a sample dataset with minimal code:

import pandas as pd
import matplotlib.pyplot as plt

dict = {'BatteryStorage': {('des-PDef3', 'Central Africa'): 0.0,
      ('des-PDef3', 'Eastern Africa'): 2475.9,
      ('des-PDef3', 'North Africa'): 98.0,
      ('des-PDef3', 'Southern Africa'): 124.0,
      ('des-PDef3', 'West Africa'): 1500.24,
      ('pes-PDef3', 'Central Africa'): 0.0,
      ('pes-PDef3', 'Eastern Africa'): 58.03,
      ('pes-PDef3', 'North Africa'): 98.0,
      ('pes-PDef3', 'Southern Africa'): 124.0,
      ('pes-PDef3', 'West Africa'): 0.0,
      ('tes-PDef3', 'Central Africa'): 0.0,
      ('tes-PDef3', 'Eastern Africa'): 1175.86,
      ('tes-PDef3', 'North Africa'): 98.0,
      ('tes-PDef3', 'Southern Africa'): 124.0,
      ('tes-PDef3', 'West Africa'): 0.0},
     'Biomass PP': {('des-PDef3', 'Central Africa'): 44.24,
      ('des-PDef3', 'Eastern Africa'): 1362.4,
      ('des-PDef3', 'North Africa'): 178.29,
      ('des-PDef3', 'Southern Africa'): 210.01999999999998,
      ('des-PDef3', 'West Africa'): 277.4,
      ('pes-PDef3', 'Central Africa'): 44.24,
      ('pes-PDef3', 'Eastern Africa'): 985.36,
      ('pes-PDef3', 'North Africa'): 90.93,
      ('pes-PDef3', 'Southern Africa'): 144.99,
      ('pes-PDef3', 'West Africa'): 130.33,
      ('tes-PDef3', 'Central Africa'): 44.24,
      ('tes-PDef3', 'Eastern Africa'): 1362.4,
      ('tes-PDef3', 'North Africa'): 178.29,
      ('tes-PDef3', 'Southern Africa'): 210.01999999999998,
      ('tes-PDef3', 'West Africa'): 277.4}}

df = pd.DataFrame.from_dict(dict)
df.plot(kind = "bar",stacked = True)
plt.show()

Solution

  • I struggled a bit to find a way to draw lines outside the plot area, but found a creative solution in this previous thread: How to draw a line outside of an axis in matplotlib (in figure coordinates). Thanks again to its author!

    My proposed solution is the following (the comments in the code explain the individual parts):

    import pandas as pd
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    
    
    dict = {'BatteryStorage': {('des-PDef3', 'Central Africa'): 0.0,
          ('des-PDef3', 'Eastern Africa'): 2475.9,
          ('des-PDef3', 'North Africa'): 98.0,
          ('des-PDef3', 'Southern Africa'): 124.0,
          ('des-PDef3', 'West Africa'): 1500.24,
          ('pes-PDef3', 'Central Africa'): 0.0,
          ('pes-PDef3', 'Eastern Africa'): 58.03,
          ('pes-PDef3', 'North Africa'): 98.0,
          ('pes-PDef3', 'Southern Africa'): 124.0,
          ('pes-PDef3', 'West Africa'): 0.0,
          ('tes-PDef3', 'Central Africa'): 0.0,
          ('tes-PDef3', 'Eastern Africa'): 1175.86,
          ('tes-PDef3', 'North Africa'): 98.0,
          ('tes-PDef3', 'Southern Africa'): 124.0,
          ('tes-PDef3', 'West Africa'): 0.0},
         'Biomass PP': {('des-PDef3', 'Central Africa'): 44.24,
          ('des-PDef3', 'Eastern Africa'): 1362.4,
          ('des-PDef3', 'North Africa'): 178.29,
          ('des-PDef3', 'Southern Africa'): 210.01999999999998,
          ('des-PDef3', 'West Africa'): 277.4,
          ('pes-PDef3', 'Central Africa'): 44.24,
          ('pes-PDef3', 'Eastern Africa'): 985.36,
          ('pes-PDef3', 'North Africa'): 90.93,
          ('pes-PDef3', 'Southern Africa'): 144.99,
          ('pes-PDef3', 'West Africa'): 130.33,
          ('tes-PDef3', 'Central Africa'): 44.24,
          ('tes-PDef3', 'Eastern Africa'): 1362.4,
          ('tes-PDef3', 'North Africa'): 178.29,
          ('tes-PDef3', 'Southern Africa'): 210.01999999999998,
          ('tes-PDef3', 'West Africa'): 277.4}}
    
    
    
    df = pd.DataFrame.from_dict(dict)
    df.plot(kind = "bar",stacked = True)
    
    region_labels = [idx[1] for idx in df.index]   #deriving the part needed for the x-labels from dict
    
    plt.tight_layout()      #necessary for an appropriate display
    
    plt.legend(loc='center left', fontsize=8, frameon=False, bbox_to_anchor=(1, 0.5))   #placing legend outside the plot area as in the Excel example
    
    ax = plt.gca()
    ax.set_xticklabels(region_labels, rotation=90)
    
    #coloring labels for easier interpretation
    for i, label in enumerate(ax.get_xticklabels()):
        #print(i)
        if i <= 4:
            label.set_color('red')  #set favoured colors here
        if 9 >= i > 4:
            label.set_color('green') 
        if i > 9:
            label.set_color('blue')  
    
    
    plt.text(1/6, -0.5, 'des', fontweight='bold', transform=ax.transAxes, ha='center', color='red')     #adding labels outside the plot area, representing the scenario code (first index level)
    plt.text(3/6, -0.5, 'pes', fontweight='bold', transform=ax.transAxes, ha='center', color='green')   #keep coloring consistent with the tick labels
    plt.text(5/6, -0.5, 'tes', fontweight='bold', transform=ax.transAxes, ha='center', color='blue')    #third group is 'tes' (index order: des, pes, tes)
    plt.text(5/6, -0.6, 'b', color='white', transform=ax.transAxes, ha='center')                         #phantom text to trick `tight_layout`, thus making space for the texts above
    
    
    ax2 = plt.axes([0,0,1,1], facecolor=(1,1,1,0))          #for adding lines (i.e., brackets) outside the plot area, we create new axes
    
    
    #creating the first bracket
    x_start = 0 + 0.015
    x_end = 1/3 - 0.015
    y = -0.42 
    
    bracket1 = [
        Line2D([x_start, x_start], [y, y - 0.02], transform=ax.transAxes, color='black', lw=1.5),
        Line2D([x_start, x_end], [y - 0.02, y - 0.02], transform=ax.transAxes, color='black', lw=1.5),
        Line2D([x_end, x_end], [y - 0.02, y], transform=ax.transAxes, color='black', lw=1.5),
    ]
    
    for line in bracket1:
        ax2.add_line(line)
    
    
    #second bracket
    x_start = 1/3 + 0.015
    x_end = 2/3 - 0.015 
    
    bracket2 = [
        Line2D([x_start, x_start], [y, y - 0.02], transform=ax.transAxes, color='black', lw=1.5),
        Line2D([x_start, x_end], [y - 0.02, y - 0.02], transform=ax.transAxes, color='black', lw=1.5),
        Line2D([x_end, x_end], [y - 0.02, y], transform=ax.transAxes, color='black', lw=1.5),
    ]
    
    for line in bracket2:
        ax2.add_line(line)
    
    
    #third bracket
    x_start = 2/3 + 0.015
    x_end = 1 - 0.015 
    
    bracket3 = [
        Line2D([x_start, x_start], [y, y - 0.02], transform=ax.transAxes, color='black', lw=1.5),
        Line2D([x_start, x_end], [y - 0.02, y - 0.02], transform=ax.transAxes, color='black', lw=1.5),
        Line2D([x_end, x_end], [y - 0.02, y], transform=ax.transAxes, color='black', lw=1.5),
    ]
    
    for line in bracket3:
        ax2.add_line(line)
    
    
    ax2.axis("off")     #turn off axes for the new axes
    
    plt.tight_layout()
    plt.show()
    

    This results in the following plot:

    [image]
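
As a possible variation, the hard-coded positions above (`1/6`, `3/6`, `5/6`, `i <= 4`) could be derived from the index itself. The sketch below is not part of the original solution: `group_spans` and `plot_grouped_bars` are hypothetical helper names, and the vertical offsets are assumed values borrowed from the code above that will likely need tuning for other label lengths. It relies on `ax.get_xaxis_transform()`, where x is given in data coordinates (the bar positions 0, 1, 2, ...) and y in axes fractions, so the brackets can be drawn on the same axes with `clip_on=False` instead of on a second transparent axes.

```python
from itertools import groupby

import matplotlib.pyplot as plt
import pandas as pd


def group_spans(index, level=0):
    """Return (label, first_pos, last_pos) for each run of equal values in `level`."""
    values = list(index.get_level_values(level))
    spans = []
    pos = 0
    for label, run in groupby(values):
        n = len(list(run))
        spans.append((label, pos, pos + n - 1))
        pos += n
    return spans


def plot_grouped_bars(df, y_bracket=-0.42, y_text=-0.5):
    """Stacked bar plot: inner level as tick labels, outer level as brackets.

    y_bracket and y_text are axes-fraction offsets (assumed values, chosen by eye).
    """
    ax = df.plot(kind="bar", stacked=True)
    ax.set_xticklabels(list(df.index.get_level_values(1)), rotation=90)

    # Blended transform: x in data coordinates, y in axes fractions.
    trans = ax.get_xaxis_transform()
    for label, first, last in group_spans(df.index):
        left, right = first - 0.4, last + 0.4
        # One polyline forms the bracket: down, across, up.
        ax.plot([left, left, right, right],
                [y_bracket, y_bracket - 0.02, y_bracket - 0.02, y_bracket],
                transform=trans, color="black", lw=1.5, clip_on=False)
        # Group label centred under its bracket.
        ax.text((first + last) / 2, y_text, label, transform=trans,
                ha="center", va="top", fontweight="bold")

    ax.figure.tight_layout()  # makes room for the text below the axes
    return ax
```

With the `df` from the question, `plot_grouped_bars(df)` followed by `plt.show()` should label the three groups `des-PDef3`, `pes-PDef3` and `tes-PDef3`. This assumes the rows are already ordered so that equal first-level values are adjacent; calling `df.sort_index()` first would guarantee that.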