python, numpy, matplotlib

Plotting heatmap with gridlines in matplotlib misses gridlines


I am trying to plot a heatmap with gridlines. This is my code (adapted from this post):

# Plot a heatmap with gridlines

import numpy as np
import matplotlib.pyplot as plt

arr = np.random.randn(3, 20)

plt.tight_layout()
ax = plt.subplot(111)
ax.imshow(arr, cmap='viridis')

xr = ax.get_xlim()
yr = ax.get_ylim()
ax.set_xticks(np.arange(max(xr))-0.5, minor=True)
ax.set_yticks(np.arange(max(yr))-0.5, minor=True)
ax.grid(which='minor', snap=False, color='k', linestyle='-', linewidth=1)
ax.tick_params(which='major', bottom=False, left=False)
ax.tick_params(which='minor', bottom=False, left=False)
ax.set_xticklabels([])
ax.set_yticklabels([])
for spine in ax.spines.values():
    spine.set_visible(False)


plt.show()

I get the following plot (cropped to content):

[Image: bad heatplot]

There are no vertical gridlines after the 3rd, the 8th, and the third-to-last columns. Since this script draws the gridlines at the minor ticks, I also printed the minor tick positions and made the same plot without hiding the labels:

print('xticks:', np.arange(max(xr))-0.5)

xticks: [-0.5  0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5  9.5 10.5 11.5 12.5
 13.5 14.5 15.5 16.5 17.5 18.5]

[Image: bad heatmap with labels]

which shows that all the necessary ticks are there. What could be the issue here?


Solution

  • I have no idea why this happens. Using the major grid instead fixes this issue:

    import numpy as np
    import matplotlib.pyplot as plt
    
    arr = np.random.randn(3, 20)
    
    plt.tight_layout()
    ax = plt.subplot(111)
    ax.imshow(arr, cmap='viridis')
    
    xr = ax.get_xlim()
    yr = ax.get_ylim()
    ax.set_xticks(np.arange(max(xr))-0.5, minor=False)
    ax.set_yticks(np.arange(max(yr))-0.5, minor=False)
    
    ax.grid(which='major', snap=False, color='k', linestyle='-', linewidth=1)
    
    ax.tick_params(which='major', bottom=False, left=False)
    ax.tick_params(which='minor', bottom=False, left=False)
    
    plt.show()
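
A possible explanation, offered as an assumption rather than something confirmed above: since matplotlib 3.1, a minor tick that coincides with a major tick is dropped by default, and its gridline goes with it. If the automatic major ticks fall on multiples of 2.5, they would collide with the minor ticks at 2.5, 7.5 and 17.5, which matches the missing lines. The sketch below keeps the minor grid but clears the major ticks so that nothing can overlap. The names `n_rows`, `n_cols`, `x_edges` and `y_edges` are illustrative and not from the original code.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; remove to use plt.show()
import numpy as np
import matplotlib.pyplot as plt

arr = np.random.randn(3, 20)
n_rows, n_cols = arr.shape

fig, ax = plt.subplots()
ax.imshow(arr, cmap='viridis')

# Cell boundaries sit half a unit either side of the integer cell centres.
x_edges = np.arange(n_cols + 1) - 0.5
y_edges = np.arange(n_rows + 1) - 0.5

# Clear the major ticks first, so no minor tick can coincide with one.
ax.set_xticks([])
ax.set_yticks([])

# Put the minor ticks on the cell boundaries and draw the grid there.
ax.set_xticks(x_edges, minor=True)
ax.set_yticks(y_edges, minor=True)
ax.grid(which='minor', color='k', linestyle='-', linewidth=1)

# Hide the tick marks and the frame, as in the question.
ax.tick_params(which='minor', bottom=False, left=False)
for spine in ax.spines.values():
    spine.set_visible(False)

# Every boundary should now survive as a minor tick location.
print('minor xticks:', ax.xaxis.get_minorticklocs())
```

To view the result, drop the `matplotlib.use("Agg")` line and call `plt.show()`, or save it with `fig.savefig(...)`. If the major ticks are needed for labels, setting `ax.xaxis.remove_overlapping_locs = False` (and the same on `ax.yaxis`) should be an alternative way to keep the overlapping minor ticks.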