Tags: python, matplotlib, subplot, imshow

Turn off axes in subplots


I have the following code:

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.cm as cm

img = mpimg.imread("lena.jpg")

fig, axs = plt.subplots(2, 2)

# Top-left panel: the image exactly as loaded from disk.
top_left = axs[0, 0]
top_left.imshow(img, cmap=cm.Greys_r)
top_left.set_title("Rank = 512")

# Remaining panels: the output of the asker's prune_matrix helper (its
# definition is not shown here; presumably a rank-`rank` approximation of
# img) at decreasing ranks, filled in row-major order.
for (row, col), rank in zip([(0, 1), (1, 0), (1, 1)], [128, 32, 16]):
    new_img = prune_matrix(rank, img)
    axs[row, col].imshow(new_img, cmap=cm.Greys_r)
    axs[row, col].set_title("Rank = %s" % rank)

plt.show()

However, the result is pretty ugly because of the values on the axes:

[Image: the 2×2 subplots, with tick values cluttering every axis]

How can I turn off axes values for all subplots simultaneously?

The answer to "How to remove axis, legends, and white padding" doesn't help me, because I don't know how to apply it to subplots.


Solution

  • import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    import matplotlib.cm as cm
    import matplotlib.cbook as cbook  # used for matplotlib sample image
    
    # Load a sample image that ships with matplotlib.
    with cbook.get_sample_data('grace_hopper.jpg') as image_file:
        img = plt.imread(image_file)

    # Or read a local file instead:
    # img = mpimg.imread("file.jpg")

    # imshow ignores `cmap` for RGB(A) data (height x width x 3/4), and the
    # sample photo is a colour JPEG. Average the colour channels (dropping any
    # alpha channel) to get a 2-D grayscale matrix so cmap=cm.Greys_r actually
    # takes effect. Data that is already 2-D is left untouched.
    if img.ndim == 3:
        img = img[..., :3].mean(axis=2)

    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(8, 8), tight_layout=True)

    # axis("off") hides this Axes' axis lines, ticks and tick labels;
    # it must be called on every subplot you want to clean up.
    axs[0, 0].imshow(img, cmap=cm.Greys_r)
    axs[0, 0].set_title("Rank = 512")
    axs[0, 0].axis("off")

    axs[0, 1].imshow(img, cmap=cm.Greys_r)
    axs[0, 1].set_title("Rank = 128")
    axs[0, 1].axis("off")

    axs[1, 0].imshow(img, cmap=cm.Greys_r)
    axs[1, 0].set_title("Rank = 32")
    axs[1, 0].axis("off")

    axs[1, 1].imshow(img, cmap=cm.Greys_r)
    axs[1, 1].set_title("Rank = 16")
    axs[1, 1].axis("off")

    plt.show()
    

    [Image: the resulting 2×2 grid of images with all axes turned off]

    Note: to hide only the x axis or only the y axis, call `set_visible(False)` on that axis, e.g.:

    axs[0, 0].xaxis.set_visible(False)  # hide only the x axis of this subplot
    axs[0, 1].yaxis.set_visible(False)  # hide only the y axis of this subplot
    # The frame (spines) stays visible here; axis("off") removes it as well.
    

    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(8, 8), tight_layout=True)

    # One rank per subplot. zip() stops at the shorter sequence, so keep
    # len(ranks) equal to nrows * ncols or some subplots will stay untouched.
    ranks = [512, 128, 32, 16]

    # axs.flat walks the 2-D array of Axes in row-major order, which removes
    # the need for nested i/j loops. Using it inline (rather than rebinding
    # `axs = axs.flat`) keeps `axs` as the 2-D array, so axs[i, j] still works.
    for ax, rank in zip(axs.flat, ranks):
        ax.imshow(img, cmap=cm.Greys_r)  # cmap only applies to 2-D (single-channel) data
        ax.set_title(f'Rank = {rank}')
        ax.axis('off')

    plt.show()