matplotlib overriding z-order

How to override mpl_toolkits.mplot3d.Axes3D.draw() method?


I'm working on a small project that requires working around a bug in matplotlib in order to fix the zorders of some ax.patches and ax.collections. More precisely, the ax.patches are symbols that can be rotated in space, and the ax.collections are the sides of ax.voxels (so text must be placed on them). What I know so far is that the bug is hidden in the draw method of mpl_toolkits.mplot3d.Axes3D: the zorders are recalculated in an undesired way each time I move my diagram. So I decided to change the definition of the draw method in these lines:

    # Rank the collections by the value returned from do_3d_projection
    # (largest first) and add that rank to each collection's manually
    # assigned stable_zorder.
    for i, col in enumerate(
            sorted(self.collections,
                   key=lambda col: col.do_3d_projection(renderer),
                   reverse=True)):
        # Stock line, now disabled: col.zorder = zorder_offset + i
        col.zorder = col.stable_zorder + i  # line added in its place
    # Same treatment for the patches.
    for i, patch in enumerate(
            sorted(self.patches,
                   key=lambda patch: patch.do_3d_projection(renderer),
                   reverse=True)):
        # Stock line, now disabled: patch.zorder = zorder_offset + i
        patch.zorder = patch.stable_zorder + i  # line added in its place

It's assumed that every object in ax.collections and ax.patches has a stable_zorder attribute, which is assigned manually in my project. So every time I run my project, I must make sure that the mpl_toolkits.mplot3d.Axes3D.draw method has been changed manually (outside my project). How can I avoid this manual change and override the method from inside my project instead?

This is MWE of my project:

import matplotlib.pyplot as plt
import numpy as np
#from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.art3d as art3d
from matplotlib.text import TextPath
from matplotlib.transforms import Affine2D
from matplotlib.patches import PathPatch

class VisualArray:
    """Draw an array as a block of voxels with its values written on the faces.

    The array is stored as three-dimensional data in ``self.arr``; inputs with
    fewer dimensions are padded with leading axes of length one.

    Every text patch and voxel collection created here receives a
    ``stable_zorder`` attribute. This class only writes that attribute; it is
    meant to be read by a customised ``Axes3D.draw``.
    """

    def __init__(self, arr, fig=None, ax=None):
        """Store the data as a 3-D array and set up the figure and axes.

        Parameters
        ----------
        arr : array_like
            Data to display, with at most three dimensions.
        fig : optional
            Figure to draw on. ``plt.figure()`` is called when omitted.
        ax : optional
            3D axes to draw on. When omitted, a 3D subplot is added to the
            figure.

        Raises
        ------
        NotImplementedError
            If ``arr`` has more than three dimensions.
        """
        # Accept lists and scalars too, not only objects that have .shape.
        arr = np.asanyarray(arr)
        if arr.ndim > 3:
            raise NotImplementedError('More than 3 dimensions is not supported')
        if arr.ndim < 3:
            # Pad with leading singleton axes:
            # () -> (1, 1, 1), (n,) -> (1, 1, n), (m, n) -> (1, m, n).
            arr = arr.reshape((1,) * (3 - arr.ndim) + arr.shape)
        self.arr = arr
        self.fig = plt.figure() if fig is None else fig
        if ax is None:
            # Figure.gca() no longer accepts keyword arguments in current
            # Matplotlib releases; add_subplot is the supported way to ask
            # for a 3D projection.
            self.ax = self.fig.add_subplot(111, projection='3d')
        else:
            self.ax = ax
        # Initial viewing angles.
        self.ax.azim, self.ax.elev = -120, 30
        # Passed to ax.voxels as facecolors by voxelize().
        self.colors = None

    def text3d(self, xyz, s, zdir="z", zorder=1, size=None, angle=0, usetex=False, **kwargs):
        """Add the string ``s`` to the axes as a flat patch placed in 3D.

        Parameters
        ----------
        xyz : sequence of three numbers
            Position of the text in data coordinates.
        s : str
            Text to render.
        zdir : {"x", "y", "z", "-x", "-y", "-z"}
            Direction handed to ``art3d.pathpatch_2d_to_3d``. A leading "-"
            additionally applies the matching matrix from the table below to
            the text outline.
        zorder : number
            Stored on the patch as ``stable_zorder``.
        size, usetex
            Forwarded to ``TextPath``.
        angle : number
            Rotation handed to ``Affine2D.rotate``.
        **kwargs
            Forwarded to ``PathPatch``.

        Returns
        -------
        The ``PathPatch`` that was added to the axes.
        """
        # Extra matrices applied to the text outline for the "-" directions.
        extra_transforms = {
            '-x': np.array([[-1.0, 0.0, 0], [0.0, 1.0, 0.0], [0, 0.0, -1]]),
            '-y': np.array([[0.0, 1.0, 0], [-1.0, 0.0, 0.0], [0, 0.0, 1]]),
            '-z': np.array([[1.0, 0.0, 0], [0.0, -1.0, 0.0], [0, 0.0, -1]]),
        }

        # Reorder the coordinates so that (x, y) lie in the text plane and z
        # is the offset along zdir. Nothing to do for the z direction.
        x, y, z = xyz
        if "y" in zdir:
            x, y, z = x, z, y
        elif "x" in zdir:
            x, y, z = y, z, x

        text_path = TextPath((-0.5, -0.5), s, size=size, usetex=usetex)
        trans = Affine2D().rotate(angle)
        if '-' in zdir:
            # Left-multiply the rotation by the extra matrix. Going through
            # the constructor avoids assigning to the private ``_mtx``.
            trans = Affine2D(np.dot(extra_transforms[zdir], trans.get_matrix()))
        trans = trans.translate(x, y)

        p = PathPatch(trans.transform_path(text_path), **kwargs)
        self.ax.add_patch(p)
        art3d.pathpatch_2d_to_3d(p, z=z, zdir=zdir)
        p.stable_zorder = zorder
        return p

    @staticmethod
    def _assign_zorders(sides, zorders):
        """Set ``stable_zorder`` of every patch in each side to the paired value."""
        for side, zorder in zip(sides, zorders):
            for patch in side:
                patch.stable_zorder = zorder

    def on_rotation(self, event):
        """Update ``stable_zorder`` of all face labels from the current view.

        Registered for 'motion_notify_event' by ``labelize``. ``event`` is not
        used; the angles are read from ``self.ax.elev`` and ``self.ax.azim``.
        """
        # side1/side4: the sign pattern flips with the sign of the elevation.
        vrot_idx = 0 if self.ax.elev > 0 else 1
        v_zorders = 10000 * np.array([(1, -1), (-1, 1)])[vrot_idx]
        self._assign_zorders((self.side1, self.side4), v_zorders)

        # The other four sides: index of the first bound that azim is below,
        # i.e. 0 for azim < -90, 1 for < 0, 2 for < 90, otherwise 3.
        azim = self.ax.azim
        hrot_idx = next(
            (i for i, bound in enumerate((-90, 0, 90)) if azim < bound), 3)
        h_zorders = 10000 * np.array([(1, 1, -1, -1), (-1, 1, 1, -1),
                                      (-1, -1, 1, 1), (1, -1, -1, 1)])[hrot_idx]
        self._assign_zorders(
            (self.side3, self.side2, self.side6, self.side5), h_zorders)

    def voxelize(self):
        """Draw one voxel per array element and record each collection's zorder."""
        # Every cell is filled; the voxel grid uses the reversed array shape.
        filled = np.ones(self.arr.shape[::-1], dtype=bool)
        self.ax.voxels(filled, facecolors=self.colors, edgecolor='k')
        for col in self.ax.collections:
            col.stable_zorder = col.zorder

    def _label_side(self, positions, labels, zdir, zorder):
        """Create one text patch per (position, label) pair; return them as a list."""
        return [
            self.text3d(xyz, f'${n}$', zdir=zdir, zorder=zorder, size=1,
                        usetex=True, ec="none", fc="k")
            for xyz, n in zip(positions, labels)
        ]

    def labelize(self):
        """Write the array values on the six outer faces of the voxel block.

        Fills ``self.side1`` .. ``self.side6`` with the created patches and
        connects ``on_rotation`` to the canvas' 'motion_notify_event'.
        """
        # Make sure all six attributes exist before the callback is connected.
        self.side1, self.side2, self.side3 = [], [], []
        self.side4, self.side5, self.side6 = [], [], []
        self.fig.canvas.mpl_connect('motion_notify_event', self.on_rotation)
        s = self.arr.shape

        # side1 (z = s[0]) and side4 (z = 0).
        # ``surf`` holds cell-centre coordinates: grid indices shifted by 0.5.
        surf = np.indices((s[2], s[1])).T[::-1].reshape(-1, 2) + 0.5
        self.side1 = self._label_side(
            np.insert(surf, 2, s[0], axis=1), self.arr[0].flatten(),
            "z", 10000)
        self.side4 = self._label_side(
            np.insert(surf, 2, 0, axis=1), self.arr[-1].flatten(),
            "-z", -10000)

        # side2 (y = 0) and side5 (y = s[1]); the two faces use differently
        # ordered index grids.
        surf = np.indices((s[2], s[0])).T[::-1].reshape(-1, 2) + 0.5
        near_positions = np.insert(surf, 1, 0, axis=1)
        surf = np.indices((s[0], s[2])).T[::-1].reshape(-1, 2) + 0.5
        far_positions = np.insert(surf, 1, s[1], axis=1)
        self.side2 = self._label_side(
            near_positions, self.arr[:, -1].flatten(), "y", 10000)
        self.side5 = self._label_side(
            far_positions, self.arr[::-1, 0].T[::-1].flatten(), "-y", -10000)

        # side6 (x = s[2]) and side3 (x = 0).
        surf = np.indices((s[1], s[0])).T[::-1].reshape(-1, 2) + 0.5
        self.side6 = self._label_side(
            np.insert(surf, 0, s[2], axis=1), self.arr[:, ::-1, -1].flatten(),
            "x", -10000)
        self.side3 = self._label_side(
            np.insert(surf, 0, 0, axis=1), self.arr[:, ::-1, 0].flatten(),
            "-x", 10000)

    def vizualize(self):
        """Draw the voxels, label their faces and switch the axes off."""
        self.voxelize()
        self.labelize()
        # Act on this object's own axes; plt.axis('off') would instead hit
        # whatever pyplot currently regards as the active axes.
        self.ax.set_axis_off()

# Demo: show a 2 x 6 x 5 block holding the integers 0..59.
arr = np.arange(2 * 6 * 5).reshape(2, 6, 5)
va = VisualArray(arr)
va.vizualize()
plt.show()

This is the output I get after externally changing the ...\mpl_toolkits\mplot3d\axes3d.py file:

enter image description here

This is the (unwanted) output I get if no change is made:

enter image description here


Solution

  • What you want to achieve is called Monkey Patching.

    It has its downsides and has to be used with some care (there is plenty of information available under this keyword). But one option could look something like this:

    from matplotlib import artist
    from mpl_toolkits.mplot3d import Axes3D
    
    # Create a new draw function
    # Create a new draw function
    @artist.allow_rasterization
    def draw(self, renderer):
        """Template for a replacement ``Axes3D.draw``.

        Intended to be assigned to ``Axes3D.draw``, so ``self`` is the
        Axes3D instance being drawn and ``renderer`` is the argument the
        original method receives.
        """
        # Your version
        # ...
    
        # Add Axes3D explicitly to super() calls: this function is defined
        # outside the class body, so the zero-argument super() form is not
        # available here.
        super(Axes3D, self).draw(renderer)
    
    # Overwrite the old draw function
    Axes3D.draw = draw
    
    # The rest of your code
    # ...
    

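    The caveats here are that you need to import artist for the decorator, and that you must call super(Axes3D, self).method() explicitly instead of just super().method(), because the zero-argument form only works in functions defined inside a class body.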

    Depending on your use case and to stay compatible with the rest of your code you could also save the original draw function and use the custom only temporarily:

    def draw_custom(self, renderer):
        # Once assigned to Axes3D.draw this is called as a method with the
        # renderer, so it needs the (self, renderer) signature.
        ...
    
    # Keep a reference to the original so it can be put back afterwards.
    draw_org = Axes3D.draw
    Axes3D.draw = draw_custom
    try:
        # Do custom stuff
        ...
    finally:
        # Restore the original draw even if the custom stuff raises.
        Axes3D.draw = draw_org
    
    # Do normal stuff