Suppose I have 3 directories of .jpg files: dataset 1, dataset 2, dataset 3.
I would like to make a 5 by 3 grid of subplots using matplotlib. In each row, the subplots show an image from dataset 1, dataset 2 and dataset 3, in that order. The expected format is like this:
plot1, plot2, plot3,
plot4.......
plot13, plot14, plot15.
How should I do that?
something like this:
plt.figure(figsize=(10, 10))
for data1, data2, data3 in dataset1, dataset2, dataset3:
....
Use `Path(...).glob()` from `pathlib` to find all of the image paths in each directory, and unpack them in a list comprehension.
`matplotlib.pyplot.imread` and `matplotlib.axes.Axes.imshow` are used to read and show the images, respectively.
Tested in python 3.12.0, matplotlib 3.8.1
import matplotlib.pyplot as plt
from pathlib import Path
# Show the images on a 5 x 3 grid: each row holds one image from dataset1,
# dataset2 and dataset3, in that order, with the file name as the title.
# create a list of directories, one per column of the grid
dirs = ['../Pictures/dataset1', '../Pictures/dataset2', '../Pictures/dataset3']
# collect the image paths of each directory; glob order depends on the file
# system, so sort them to get a reproducible layout
files_by_dir = [sorted(Path(dir_).glob('*.jpg')) for dir_ in dirs]
# interleave the directories so consecutive paths are (dataset1, dataset2,
# dataset3), which matches the row-major order of the flattened axes;
# zip stops at the shortest directory, so every row that is drawn is complete
files = [f for row in zip(*files_by_dir) for f in row]
# create the figure
fig, axes = plt.subplots(nrows=5, ncols=3, figsize=(10, 10), tight_layout=True)
# flatten the axes into a 1-d array to make it easier to access each axes
axes = axes.flatten()
# iterate through the axes and the associated file
for ax, file in zip(axes, files):
    # read the image in
    pic = plt.imread(file)
    # add the image to the axes
    ax.imshow(pic)
    # add an axes title; .stem is the file name without its suffix
    ax.set(title=file.stem)
# remove ticks / labels on every axes, including cells left without an image
for ax in axes:
    ax.axis('off')
# add a figure title
_ = fig.suptitle('Images from https://www.heroforge.com/', fontsize=18)
# Redraw the 5 x 3 grid with no spacing between the images, by giving the
# figure the same aspect ratio as the stacked images.
# read in all the images, which are all the same size
images = [plt.imread(file) for file in files]
# get heights for images (shape[0] is the number of pixel rows);
# the number must match the number for nrows
heights = [im.shape[0] for im in images[:5]]  # [images[0].shape[0]] * 5
# get widths for images (shape[1] is the number of pixel columns);
# the number must match the number for ncols
widths = [im.shape[1] for im in images[:3]]  # [images[0].shape[1]] * 3
# set the figure width in inches
fig_width = 9
# calculate the figure height from the height / width ratio of the grid
fig_height = fig_width * sum(heights) / sum(widths)
# create the figure with all padding and spacing removed
fig, axes = plt.subplots(nrows=5, ncols=3, figsize=(fig_width, fig_height),
                         gridspec_kw={'wspace': 0, 'hspace': 0, 'left': 0, 'right': 1,
                                      'bottom': 0, 'top': 1, 'height_ratios': heights})
# flatten the axes into a 1-d array to make it easier to access each axes
axes = axes.flatten()
# iterate through the axes and associated images
for ax, image in zip(axes, images):
    # add the image to the axes
    ax.imshow(image)
    # remove ticks / labels
    ax.axis('off')