
# Plot Imshow on a Torus (lattice with PBC)

Is there any "simple" solution in matplotlib or another Python package to plot a square lattice on a 2d torus (aka lattice with periodic boundary conditions)?

Assume I have a simple 2D array:

```python
import numpy as np
import matplotlib.pyplot as plt

a = np.random.random((50, 50))
plt.imshow(a)
plt.show()
```

I would like to wrap this plane into a torus, which can be drawn, e.g., with:

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on matplotlib < 3.2

# Generating Torus Mesh
angle = np.linspace(0, 2 * np.pi, 100)
theta, phi = np.meshgrid(angle, angle)
r, R = .25, 1.
X = (R + r * np.cos(phi)) * np.cos(theta)
Y = (R + r * np.cos(phi)) * np.sin(theta)
Z = r * np.sin(phi)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')  # the original snippet used ax without creating it
ax.set_xlim3d(-1, 1)
ax.set_ylim3d(-1, 1)
ax.set_zlim3d(-1, 1)
ax.plot_surface(X, Y, Z, rstride=1, cstride=1)

plt.show()
```
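As a quick sanity check on the parameterization (not part of the original question), every vertex should satisfy the implicit torus equation `(sqrt(X**2 + Y**2) - R)**2 + Z**2 == r**2`. Because both angles include 0 and 2*pi, the first and last columns (theta) and rows (phi) of the mesh coincide, which is what closes the surface the way periodic boundary conditions close a lattice. A minimal NumPy sketch that rebuilds the same mesh so it runs on its own:

```python
import numpy as np

# Same torus mesh as above: theta runs around the central circle, phi around the tube
angle = np.linspace(0, 2 * np.pi, 100)
theta, phi = np.meshgrid(angle, angle)
r, R = 0.25, 1.0
X = (R + r * np.cos(phi)) * np.cos(theta)
Y = (R + r * np.cos(phi)) * np.sin(theta)
Z = r * np.sin(phi)

# Each vertex lies on the torus (sqrt(X^2 + Y^2) - R)^2 + Z^2 = r^2
on_torus = np.allclose((np.hypot(X, Y) - R) ** 2 + Z ** 2, r ** 2)

# theta varies along columns, phi along rows; 0 and 2*pi give the same point,
# so the mesh closes on itself in both directions
closed_theta = np.allclose(X[:, 0], X[:, -1]) and np.allclose(Y[:, 0], Y[:, -1])
closed_phi = np.allclose(X[0, :], X[-1, :]) and np.allclose(Z[0, :], Z[-1, :])
print(on_torus, closed_theta, closed_phi)  # True True True
```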

I thought about somehow encoding the X and Y information in a colormap passed to plot_surface's cmap option, so that each face of the surface gets the color of the corresponding entry of array a.

Ideas?

Solution

• First, recall the floor function: floor(x) is the largest integer less than or equal to x.
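Floor rounds toward negative infinity, which differs from a plain integer cast for negative inputs. The index formula in the code below produces negative intermediate values, so the distinction matters there. A short NumPy illustration:

```python
import numpy as np

x = np.array([-1.5, -0.5, 0.0, 0.5, 1.5])

# floor rounds toward minus infinity ...
floored = np.floor(x).astype(int)   # [-2, -1, 0, 0, 1]
# ... whereas a plain cast truncates toward zero
truncated = x.astype(int)           # [-1, 0, 0, 0, 1]
print(floored, truncated)
```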

theta and phi increase "continuously" from 0 to 2*pi. We can scale them to the dimensions of the matrix a and use the floor function to turn them into integer indexes, which then pick the proper color from your matrix.
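To see why the mapping is seamless, here is a small sketch of the same shift, scale, floor, +1 formula used in the code below, applied to one axis. The helper name `angle_to_index` is hypothetical, introduced only for this illustration. Some indexes come out negative, and NumPy reads those from the end of the axis, so an angle of 0 and an angle of 2*pi land on the same cell:

```python
import numpy as np

def angle_to_index(angle, n):
    """Hypothetical helper: map angles in [0, 2*pi] to integer indices with the
    same shift, scale, floor, +1 formula as the answer's code.
    Some results are negative; NumPy reads those from the end of the axis."""
    frac = angle / (2 * np.pi)
    return np.floor((frac - 0.5) * n).astype(int) + 1

n = 30  # matrix dimension along this axis, e.g. a.shape[1] in the answer
angles = np.linspace(0, 2 * np.pi, 100)
idx = angle_to_index(angles, n)

# Every index is valid for an axis of length n (negative ones wrap around)
print(idx.min(), idx.max())  # -14 16

# Look up a stand-in row of colors; angle 0 and angle 2*pi get the same value,
# so the colors match where the torus closes on itself
values = np.arange(n)[idx]
print(values[0] == values[-1])  # True
```

Any offset would give a valid lookup as long as the result is reduced modulo n; this formula simply shifts which cell of a sits at angle 0.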

Here is the code:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import matplotlib

# random colors
theta_dim, phi_dim = 100, 30
a = np.random.random((theta_dim, phi_dim))

# Generating Torus Mesh
angle = np.linspace(0, 2 * np.pi, 100)
theta, phi = np.meshgrid(angle, angle)
r, R = .25, 1.
X = (R + r * np.cos(phi)) * np.cos(theta)
Y = (R + r * np.cos(phi)) * np.sin(theta)
Z = r * np.sin(phi)

# compute the indexes: map theta, phi in [0, 2*pi] to integer indexes into a;
# negative results wrap around via NumPy's negative indexing, so that
# an angle of 0 and an angle of 2*pi select the same cell
t, p = [var / (2 * np.pi) for var in [theta, phi]]
t = np.floor((t - 0.5) * a.shape[0]).astype(int) + 1
p = np.floor((p - 0.5) * a.shape[1]).astype(int) + 1
# extract the color value from the matrix
colors = a[t, p]
# apply a colormap to the normalized color values
norm = Normalize(vmin=colors.min(), vmax=colors.max())
cmap = matplotlib.colormaps.get_cmap("viridis")
normalized_colors = cmap(norm(colors))

fig = plt.figure()
```