python matplotlib mplot3d matplotlib-3d bar3d

Bar plot in 3D following a given line


I want to draw a bar plot in 3d. I know how to do that using the following code:

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111, projection='3d')
nbins = 50
# for c, z in zip(['r', 'g', 'b', 'y'], [30, 20, 10, 0]):
ys = np.random.normal(loc=10, scale=10, size=2000)

hist, bins = np.histogram(ys, bins=nbins)
xs = (bins[:-1] + bins[1:])/2

ax.bar(xs, hist, zs=30, zdir='y', color='r', ec='r', alpha=0.8)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

plt.show()

This will render something like this: https://i.sstatic.net/KK2If.png

However, my goal is to make the bar plot follow a line that I give as a parameter. For example, here the parameter zdir='y' gives the plot its current direction. Ideally, I want to pass a parameter that makes the plot follow a given line, for example y=2x+1.

Could someone help me arrive at the desired result?


Solution

  • One way to achieve that is by using Poly3DCollection: the idea is to compute the coordinates and orientation of each bar, then add it to the plot.

    The position and orientation of each bar can be computed starting from a rectangle in 3D space and applying the appropriate transformation matrix.

    If you change the curve, you will also need to adjust the bar width, because the distance between consecutive bars measured along the curve changes with its slope.

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
    from matplotlib.patches import Rectangle
    
    ################
    # Generates data
    ################
    nbins = 50
    ys = np.random.normal(loc=10, scale=10, size=2000)
    hist, bins = np.histogram(ys, bins=nbins)
    xs = (bins[:-1] + bins[1:])/2
    
    #################################################
    # Create a single bar and a transformation matrix
    #################################################
    
    # rectangle of width=height=1, centered at x,y=0
    # covering the z range [0, height]
    rect = np.array([
        [-0.5, 0, 0, 1],
        [0.5, 0, 0, 1],
        [0.5, 0, 1, 1],
        [-0.5, 0, 1, 1],
    ])
    def translate(x, y, z):
        d = np.eye(4, dtype=float)
        d[:, -1] = [x, y, z, 1]
        return d
    def scale(sx, sy, sz):
        d = np.eye(4, dtype=float)
        d[np.diag_indices(4)] = [sx, sy, sz, 1]
        return d
    def rotate(t):
        # rotation about the z axis by the angle t (in radians)
        d = np.eye(4, dtype=float)
        d[:2, :2] = np.array([
            [np.cos(t), -np.sin(t)],
            [np.sin(t), np.cos(t)]])
        return d
    def transformation_matrix(t, x, y, z, w, h):
        return translate(x, y, z) @ rotate(t) @ scale(w, 1, h)
    def apply_transform(t, x, y, z, w, h):
        """Apply the transformation matrix to the rectangle"""
        verts = transformation_matrix(t, x, y, z, w, h) @ rect.T
        return verts.T
    
    #################
    # Create the plot
    #################
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    
    curve = lambda x: 2 * x + 1
    # curve = lambda x: np.sin(0.05 * x)
    
    xstep = abs(xs[0] - xs[1])
    # NOTE: choose an appropriate bar width
    width = xstep * 1.5
    
    ys = curve(xs)
    # previous bar coordinates
    xp = np.roll(xs, 1)
    yp = np.roll(ys, 1)
    xp[0] = xs[0] - xstep
    yp[0] = curve(xp[0])
    # compute the orientation of the bars
    theta = np.arctan2((ys - yp), (xs - xp))
    
    # customize the appearance of the bar
    facecolor = "tab:red"
    edgecolor = "k"
    linewidth = 0
    # loop to add each bar
    for x, y, t, h in zip(xs, ys, theta, hist):
        verts_matrix = apply_transform(t, x, y, 0, width, h)
        x, y, z = verts_matrix[:, 0], verts_matrix[:, 1], verts_matrix[:, 2]
        verts = [list(zip(x, y, z))]
        c = Poly3DCollection(verts, facecolor=facecolor, edgecolor=edgecolor, linewidth=linewidth)
        ax.add_collection3d(c)
    
    # optionally show a legend, using a Rectangle as a stand-in for the bars
    ax.legend([Rectangle((0, 0), 1, 1, facecolor=facecolor, edgecolor=edgecolor, linewidth=linewidth)], ["Bar Plot"])
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    
    ax.set_xlim(xs.min(), xs.max())
    ax.set_ylim(ys.min(), ys.max())
    # use the tallest bar as the upper limit so that no bar is clipped
    ax.set_zlim(0, hist.max())
    
    plt.show()
    

    EDIT to explain what is going on:

    Consider a generic rectangle with 4 vertices: bottom left, bottom right, top right, top left. For simplicity, let's fix width=height=1. In a reference system x,y,z, the coordinates of the vertices are: bottom left (-0.5, 0, 0), bottom right (0.5, 0, 0), top right (0.5, 0, 1) and top left (-0.5, 0, 1). Note that this rectangle is centered around zero in the x direction. If we move it to x=2, then it will be centered at that location. You can see the above coordinates in rect. Why does this variable have a fourth column filled with ones? That's a mathematical trick (homogeneous coordinates) that lets a translation be written as a matrix multiplication applied to the vertices.
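
    To make the role of that column of ones concrete, here is a minimal sketch that is not part of the original answer. The names vertex and T are only illustrative. It moves one corner of the unit rectangle with a 4x4 translation matrix, and checks that a plain 3x3 matrix cannot do the same because it always leaves the origin in place.

```python
import numpy as np

# One corner of the unit rectangle in homogeneous coordinates (x, y, z, 1)
vertex = np.array([0.5, 0.0, 1.0, 1.0])

# 4x4 translation matrix: identity with the offsets in the last column
T = np.eye(4)
T[:3, 3] = [2.0, 3.0, 0.0]

# The trailing 1 of the vertex picks up the last column of T,
# so x and y are shifted by (2, 3) while z is unchanged
moved = T @ vertex

# A 3x3 matrix (rotation, scaling, shear) always maps the origin to itself,
# so it can never express a translation on its own
origin = np.zeros(3)
some_linear_map = np.diag([2.0, 1.0, 3.0])
stays_put = some_linear_map @ origin
```

    The fourth component of the result stays 1, so the output can be fed directly into the next transformation.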

    Let's talk about transformation matrices (Wikipedia has a nice page about them). Consider again our generic rectangle: we can scale it, rotate it and translate it to get a new rectangle in the position and orientation we want.

    So, the code above defines a function for each transformation: translate, scale, rotate. It turns out that we can multiply several transformation matrices together to get an overall transformation. That's what transformation_matrix does: it combines the aforementioned transformations into a single matrix. Since the matrices act on column vectors, the rightmost one is applied first: the rectangle is scaled, then rotated about the z axis, then translated.
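
    As a small check of that idea, the sketch below (an illustration, not part of the original answer) rebuilds compact versions of the three helpers and verifies that the combined matrix gives the same result as applying the transformations one after the other. It also shows that the order of the factors matters. The helper rotate_z and the sample point p are names chosen only for this example.

```python
import numpy as np

def translate(x, y, z):
    m = np.eye(4)
    m[:3, 3] = [x, y, z]
    return m

def scale(sx, sy, sz):
    return np.diag([sx, sy, sz, 1.0])

def rotate_z(t):
    # rotation about the z axis by t radians
    m = np.eye(4)
    c, s = np.cos(t), np.sin(t)
    m[:2, :2] = [[c, -s], [s, c]]
    return m

# top-right corner of the unit rectangle, in homogeneous coordinates
p = np.array([0.5, 0.0, 1.0, 1.0])

# apply the steps one by one: scale, then rotate, then translate
step_by_step = translate(1, 2, 0) @ (rotate_z(np.pi / 2) @ (scale(2, 1, 3) @ p))

# build a single matrix first, then apply it once
M = translate(1, 2, 0) @ rotate_z(np.pi / 2) @ scale(2, 1, 3)
combined = M @ p
assert np.allclose(step_by_step, combined)

# swapping rotation and translation gives a different point:
# the translation offset itself gets rotated
swapped = (rotate_z(np.pi / 2) @ translate(1, 2, 0) @ scale(2, 1, 3)) @ p
assert not np.allclose(swapped, combined)
```

    Building the matrix once and applying it to all four vertices is what keeps the per-bar loop short.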

    Finally, I used apply_transform to apply the transformation matrix to the generic rectangle: this will compute the coordinates of the vertices of the new rectangle, in the specified position/orientation with the specified size (width, height).
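
    Regarding the note about adapting the bar width when the curve changes, one possible approach is to derive the width from the distance between neighboring bar centers measured along the curve, instead of picking a constant by hand. This is only a sketch under that assumption: bar_widths and its fill parameter are hypothetical names, not part of Matplotlib or of the code above.

```python
import numpy as np

def bar_widths(xs, ys, fill=0.9):
    """Hypothetical helper: one width per bar, based on the spacing of
    the bar centers (xs, ys) in the xy-plane."""
    # distance between consecutive bar centers
    seg = np.hypot(np.diff(xs), np.diff(ys))
    # gap to the left and to the right neighbor; the end bars reuse their only gap
    left = np.concatenate(([seg[0]], seg))
    right = np.concatenate((seg, [seg[-1]]))
    # take the smaller gap and shrink it a little to leave space between bars
    return fill * np.minimum(left, right)

# bar centers on the line y = 2x + 1
xs = np.linspace(0.0, 10.0, 11)
ys = 2 * xs + 1
widths = bar_widths(xs, ys)
```

    In the plotting loop, the resulting array could be added to the zip so that each bar gets its own width in place of the scalar width. On strongly curved lines neighboring bars are rotated relative to each other, so a smaller fill value may be needed.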