Tags: python, matplotlib, plot, mplot3d

# Plot Imshow on a Torus (lattice with PBC)

Is there any "simple" solution in `matplotlib` or another Python package to plot a square lattice on a 2d torus (aka lattice with periodic boundary conditions)?

Assume I have a simple 2D array:

``````
import numpy as np
import matplotlib.pyplot as plt

a = np.random.random((50, 50))
plt.imshow(a)
``````

I would like to wrap this plane into a torus, which can be achieved, e.g., with

``````
from mpl_toolkits.mplot3d import Axes3D  # only needed on older Matplotlib versions

# Generating Torus Mesh
angle = np.linspace(0, 2 * np.pi, 100)
theta, phi = np.meshgrid(angle, angle)
r, R = .25, 1.
X = (R + r * np.cos(phi)) * np.cos(theta)
Y = (R + r * np.cos(phi)) * np.sin(theta)
Z = r * np.sin(phi)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")  # create the 3D axes
ax.set_xlim3d(-1, 1)
ax.set_ylim3d(-1, 1)
ax.set_zlim3d(-1, 1)
ax.plot_surface(X, Y, Z, rstride=1, cstride=1)

plt.show()
``````

I thought about somehow encoding the information on the `X` and `Y` values in a colormap to pass to `plot_surface`'s `cmap` option. The colormap should give each point on the surface the color of the corresponding entry of array `a`.

Ideas?

Solution

First, let's introduce the floor function. `theta` and `phi` increase continuously from `0` to `2*pi`. We can scale them to the dimensions of the matrix `a` and use the floor function to turn the scaled values into integer indexes, then use those indexes to extract the proper color from your matrix.
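To make the idea concrete, here is a minimal sketch of the angle-to-index mapping on its own. The helper name `angle_to_index` and the variable names `n_theta` and `n_phi` are hypothetical. It also wraps with a modulo rather than the offset used in the full code, so treat it as an illustration of the principle, not the exact method.

```python
import numpy as np

def angle_to_index(angle, n):
    """Map angles in [0, 2*pi] to integer indexes in range(n), wrapping periodically."""
    frac = np.asarray(angle) / (2 * np.pi)      # rescale to [0, 1]
    return np.floor(frac * n).astype(int) % n   # floor picks the cell, % n wraps 2*pi onto 0

# hypothetical lattice size, matching the shape used in the full code
n_theta, n_phi = 100, 30
a = np.random.random((n_theta, n_phi))

# sample the two angles on a mesh that is finer than the lattice
angle = np.linspace(0, 2 * np.pi, 200)
theta, phi = np.meshgrid(angle, angle)

t = angle_to_index(theta, n_theta)
p = angle_to_index(phi, n_phi)

# fancy indexing: one lattice value per mesh point
colors = a[t, p]
```

Because of the modulo, the angles `0` and `2*pi` land on the same lattice cell, which is exactly the periodic boundary condition.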

Here is the code:

``````
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import matplotlib

# random colors
theta_dim, phi_dim = 100, 30
a = np.random.random((theta_dim, phi_dim))

# Generating Torus Mesh
angle = np.linspace(0, 2 * np.pi, 100)
theta, phi = np.meshgrid(angle, angle)
r, R = .25, 1.
X = (R + r * np.cos(phi)) * np.cos(theta)
Y = (R + r * np.cos(phi)) * np.sin(theta)
Z = r * np.sin(phi)

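# compute the indexes: first rescale the angles from [0, 2*pi] to [0, 1]
t, p = [var / (2 * np.pi) for var in [theta, phi]]
# a.shape is a tuple, so scale each angle by the length of its own axis;
# negative indexes wrap around, which is what we want on a torus
t = np.floor((t - 0.5) * a.shape[0]).astype(int) + 1
p = np.floor((p - 0.5) * a.shape[1]).astype(int) + 1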
# extract the color value from the matrix
colors = a[t, p]
# apply a colormap to the normalized color values
norm = Normalize(vmin=colors.min(), vmax=colors.max())
cmap = matplotlib.colormaps.get_cmap("viridis")
normalized_colors = cmap(norm(colors))

fig = plt.figure() 