python matplotlib

PyPlot.Table working with different colspans and rowspans


I'm looking for a way to add columns that can span multiple rows, and rows that can span multiple columns.

I currently have the code below to get the first row in.

# Calculate log-scaled widths, normalized to sum to 1
table_widths = [0.001, 0.002, 0.063, 2.0, 63.0, 150.0]
log_table_widths = np.diff(np.log10(table_widths))
log_table_widths = log_table_widths / log_table_widths.sum()

table = ax.table(cellText=[['Clay', 'Silt', 'Sand', 'Gravel', 'Cobbles']], cellLoc='center', loc='bottom', colWidths=log_table_widths)
table_widths = []
table.auto_set_font_size(False)
table.set_fontsize(8)
table.scale(1, 1.5)

This gives the following result: Current Result

However, I need to add another row to the table in which some cells span over the next row, while cells in the current row have to span multiple columns. Like so: Wanted Result. Preferably the bottom row would be the top row, but it is not a disaster if it isn't.

I've tried solving this on my own and with help from GitHub Copilot and Microsoft Copilot, but with no luck. The best we could come up with is the following:

# Calculate log-scaled widths, normalized to sum to 1
table_widths = [0.001, 0.002, 0.063, 2.0, 63.0, 150.0]
log_table_widths = np.diff(np.log10(table_widths))
log_table_widths = log_table_widths / log_table_widths.sum()

# Create the table
cell_text = [
  ['Clay', 'Silt', 'Fine', 'Medium', 'Coarse', 'Fine', 'Medium', 'Coarse'],
  ['', '', 'Sand', 'Sand', 'Sand', 'Gravel', 'Gravel', 'Gravel'],
]
col_labels = ['Clay', 'Silt', 'Fine', 'Medium', 'Coarse', 'Fine', 'Medium', 'Coarse']
col_widths = [log_table_widths, log_table_widths, log_table_widths/3, log_table_widths/3, log_table_widths/3, log_table_widths/3, log_table_widths/3, log_table_widths/3]

# Add the table to the plot
table = ax.table(cellText=cell_text, colLabels=col_labels, cellLoc='center', loc='bottom', colWidths=col_widths)
table.auto_set_font_size(False)
table.set_fontsize(8)
table.scale(1, 1.5)

# Adjust cell alignment to avoid ambiguity
for key, cell in table.get_celld().items():
  cell.set_text_props(ha='center', va='center')

This gives me the following error:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I have no clue how to solve it.

For reproducibility you can use:

import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
fig.set_figwidth(18)
fig.set_figheight(12)
fig.set_dpi(80)
# draw vertical line at: 0.002mm, 0.063mm, 2.0mm, 63mm
ax.axvline(x=0.002, color='red', linestyle='--')
ax.axvline(x=0.063, color='red', linestyle='--')
ax.axvline(x=2.0, color='red', linestyle='--')
ax.axvline(x=63.0, color='red', linestyle='--')
ax.set_xlim(0.001, 150)

# Calculate log-scaled widths, normalized to sum to 1
table_widths = [0.001, 0.002, 0.063, 2.0, 63.0, 150.0]
log_table_widths = np.diff(np.log10(table_widths))
log_table_widths = log_table_widths / log_table_widths.sum()

# Create the table
cell_text = [
  ['Clay', 'Silt', 'Fine', 'Medium', 'Coarse', 'Fine', 'Medium', 'Coarse'],
  ['', '', 'Sand', 'Sand', 'Sand', 'Gravel', 'Gravel', 'Gravel'],
]
col_labels = ['Clay', 'Silt', 'Fine', 'Medium', 'Coarse', 'Fine', 'Medium', 'Coarse']
col_widths = [log_table_widths, log_table_widths, log_table_widths/3, log_table_widths/3, log_table_widths/3, log_table_widths/3, log_table_widths/3, log_table_widths/3]

# Add the table to the plot
table = ax.table(cellText=cell_text, colLabels=col_labels, cellLoc='center', loc='bottom', colWidths=col_widths)
table.auto_set_font_size(False)
table.set_fontsize(8)
table.scale(1, 1.5)

# Adjust cell alignment to avoid ambiguity
for key, cell in table.get_celld().items():
  cell.set_text_props(ha='center', va='center')

fig.savefig('fig.png', format='png', bbox_inches='tight')

EDIT: I've managed to get rid of the error. It was caused by the way I defined col_widths: I filled it with whole arrays instead of the individual values. I now define it like this, and might find a better solution later:

col_widths = [log_table_widths[0], log_table_widths[1], log_table_widths[2] / 3, log_table_widths[2] / 3, log_table_widths[2] / 3, log_table_widths[3] / 3, log_table_widths[3] / 3, log_table_widths[3] / 3, log_table_widths[4]]

My table now looks like this: New Result

However, I have not figured out how to merge the rows and cells. I did find this post: Matplotlib table with double headers, where multiple tables are created to show multiple headers. Sadly, that approach only works for merging cells within a row, not within a column.


Solution

  • Your mistake is that log_table_widths is an array, not a scalar.
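To make the point above concrete: when a list of whole arrays is passed as colWidths, every column width is itself an array rather than a number, which is what NumPy's "truth value ... is ambiguous" message points to. The following is only a minimal sketch of one way to build a flat array of scalar widths. The names boundaries, class_widths and splits are illustrative, and the split into sub-columns (three each for Sand and Gravel) is an assumption taken from the question's layout.

```python
import numpy as np

# Size boundaries in mm, taken from the question.
boundaries = [0.001, 0.002, 0.063, 2.0, 63.0, 150.0]

# One relative width per soil class on a log axis; this is a 1-D array.
class_widths = np.diff(np.log10(boundaries))
class_widths = class_widths / class_widths.sum()

# Sub-columns per class: Clay 1, Silt 1, Sand 3, Gravel 3, Cobbles 1.
# (Assumed split, following the layout described in the question.)
splits = np.array([1, 1, 3, 3, 1])

# Divide each class width evenly over its sub-columns and repeat it once
# per sub-column, giving a flat array with one scalar per table column.
col_widths = np.repeat(class_widths / splits, splits)

print(len(col_widths))  # 9
```

Because each class width is only divided among its own sub-columns, the widths still sum to 1, and each entry of col_widths is a plain number that can be used as one column's width.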

    Matplotlib tables don't support merged cells. You can simulate it by removing the cell edges:

    import matplotlib.pyplot as plt
    import numpy as np
    
    xlims = [0.001, 150]
    # x positions (mm) of the class boundaries, drawn as vertical lines
    hlines = [0.002, 0.063, 2.0, 63.0]
    
    fig, ax = plt.subplots(figsize=(18,12), dpi=80)
    ax.set(xscale='log', xlim=xlims)
    ax.xaxis.set_visible(False)
    
    for x in hlines:
      ax.axvline(x=x, color='red', linestyle='--')
    
    # relative width of each class on the log axis, normalized to sum to 1
    log_table_widths = np.diff(np.log10(np.r_[xlims[0], hlines, xlims[1]]))
    log_table_widths /= log_table_widths.sum()
    # one column for Clay and Cobbles, three equal columns for Silt, Sand and Gravel
    log_table_widths  = np.r_[
         log_table_widths[0],
         log_table_widths[1] / 3,
         log_table_widths[1] / 3,
         log_table_widths[1] / 3,
         log_table_widths[2] / 3,
         log_table_widths[2] / 3,
         log_table_widths[2] / 3,
         log_table_widths[3] / 3,
         log_table_widths[3] / 3,
         log_table_widths[3] / 3,
         log_table_widths[4]
         ]
    
    texts = [
        ['', 'Fine', 'Medium', 'Coarse','Fine', 'Medium', 'Coarse', 'Fine', 'Medium', 'Coarse', ''],
        ['Clay', '','Silt','', '', 'Sand', '', '','Gravel','', 'Cobbles'],
    ]
    
    table = ax.table(cellText=texts, cellLoc='center', loc='bottom', colWidths=log_table_widths, fontsize=8)
    table.scale(1, 1.5)
    # in the bottom row, hide the inner vertical edges of each group of three cells
    for c in range(2, 9, 3):
      table.get_celld()[(1,c-1)].visible_edges='BTL'
      table.get_celld()[(1,c)].visible_edges='BT'
      table.get_celld()[(1,c+1)].visible_edges='BTR'
    

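The edge-hiding trick can be wrapped in a small helper so that a span of any width is handled the same way. This is only a sketch: fake_hmerge is a hypothetical name, not a matplotlib function, and the cells stay separate objects that merely look merged because the borders between them are not drawn. The six-column table is a reduced stand-in for the one in the answer.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def fake_hmerge(table, row, first_col, last_col):
    """Hide the inner vertical edges of cells first_col..last_col in one row.

    Hypothetical helper, not part of matplotlib.
    """
    cells = table.get_celld()
    for col in range(first_col, last_col + 1):
        edges = "BT"          # always keep the bottom and top borders
        if col == first_col:
            edges += "L"      # left border only on the first cell
        if col == last_col:
            edges += "R"      # right border only on the last cell
        cells[(row, col)].visible_edges = edges


fig, ax = plt.subplots()
ax.axis("off")
texts = [
    ["Fine", "Medium", "Coarse", "Fine", "Medium", "Coarse"],
    ["", "Sand", "", "", "Gravel", ""],
]
table = ax.table(cellText=texts, cellLoc="center", loc="center")
fake_hmerge(table, row=1, first_col=0, last_col=2)  # "Sand" spans 3 columns
fake_hmerge(table, row=1, first_col=3, last_col=5)  # "Gravel" spans 3 columns
```

The label has to sit in the middle cell of the span, which only works directly when the span covers an odd number of cells.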

    Update:
    To merge an even number of cells, you'll need to reposition the text. Because of the drawing order of the table, this can't be done with a simple set_position. The only way I found is to translate the text with a transform:

    import matplotlib.pyplot as plt
    import matplotlib.transforms as transforms
    import numpy as np
    
    xlims = [0.001, 150]
    hlines = [0.002, 0.063, 2.0, 63.0]
    
    fig, ax = plt.subplots(figsize=(18,12), dpi=80)
    ax.set(xscale='log', xlim=xlims)
    ax.xaxis.set_visible(False)
    
    for x in hlines:
      ax.axvline(x=x, color='red', linestyle='--')
    
    log_table_widths = np.diff(np.log10(np.r_[xlims[0], hlines, xlims[1]]))
    log_table_widths /= log_table_widths.sum()
    log_table_widths  = np.r_[
         log_table_widths[0],
         log_table_widths[1] / 3,
         log_table_widths[1] / 3,
         log_table_widths[1] / 3,
         log_table_widths[2] / 3,
         log_table_widths[2] / 3,
         log_table_widths[2] / 3,
         log_table_widths[3] / 3,
         log_table_widths[3] / 3,
         log_table_widths[3] / 3,                          
         log_table_widths[4]
         ]
    
    texts = [
        ['', 'Fine', 'Medium', 'Coarse','Fine', 'Medium', 'Coarse', 'Fine', 'Medium', 'Coarse', ''],
        ['Clay', '','Silt','', '', 'Sand', '', '','Gravel','', 'Cobbles'],
    ]
    
    table = ax.table(cellText=texts, cellLoc='center', loc='bottom', colWidths=log_table_widths, fontsize=8)
    table.scale(1, 1.5)
    for c in range(2, 9, 3):
      table.get_celld()[(1,c-1)].visible_edges = 'BTL'
      table.get_celld()[(1,c)].visible_edges = 'BT'
      table.get_celld()[(1,c+1)].visible_edges = 'BTR'
    
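    # shift the text up by 10 display units (pixels) so it sits on the border
    # between the two rows; the right value depends on the row height
    transform = transforms.Affine2D().translate(0, 10)
    for c in (0, 10):
      # hide the edge shared by rows 0 and 1 so the two cells look merged
      table.get_celld()[(0,c)].visible_edges = 'TRL'
      table.get_celld()[(1,c)].visible_edges = 'BRL'
      table.get_celld()[(1,c)].get_text().set_transform(transform)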
    

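The hard-coded offset of 10 matches this particular figure. As a rough sketch, the shift can instead be derived from the cell height. This assumes that the table stores cell heights as a fraction of the axes height, which appears to be the case for tables created with ax.table but is not something the answer above states. ScaledTranslation with fig.dpi_scale_trans expresses the offset in inches, so it should stay proportional if the figure is saved at a different dpi.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms

fig, ax = plt.subplots(figsize=(6, 4), dpi=80)
ax.axis("off")

# Two stacked cells that should look like one cell spanning two rows.
table = ax.table(cellText=[[""], ["Clay"]], cellLoc="center", loc="center", fontsize=8)
table.scale(1, 1.5)

cells = table.get_celld()
cells[(0, 0)].visible_edges = "TRL"  # hide the bottom edge of the upper cell
cells[(1, 0)].visible_edges = "BRL"  # hide the top edge of the lower cell

# Assumption: cell heights are fractions of the axes height, so multiplying
# by the axes height in pixels gives the cell height in pixels.
cell_height_px = cells[(1, 0)].get_height() * ax.bbox.height
half_height_in = cell_height_px / fig.dpi / 2

# Move the text up by half a cell, onto the hidden border between the rows.
shift = transforms.ScaledTranslation(0, half_height_in, fig.dpi_scale_trans)
cells[(1, 0)].get_text().set_transform(shift)
```

If the assumption about cell heights does not hold for your matplotlib version, the text will land at the wrong height and a manually tuned offset, as in the answer, remains the fallback.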