pythonpython-3.xmatplotlib

Issues with Axes (matplotlib) inheritance


I'm trying to mimic the plt.subplots() behavior, but with custom classes. Rather than return Axes from subplots(), I would like to return CustomAxes. I've looked at the source code and don't understand why I am getting the traceback error below.

I'm able to accomplish what I want without inheriting from Axes, but I think long term I would like to inherit from Axes. If you think this is ridiculous and there's a better way, let me know!

Code:

from matplotlib.figure import Figure
from matplotlib.axes import Axes

class CustomAxes(Axes):
    
    def __init__(self, fig, *args, **kwargs):
        super().__init__(fig, *args, **kwargs)
    
    def create_plot(self, i):
        self.plot([1, 2, 3], [1, 2, 3])
        self.set_title(f'Title {i}')

class CustomFigure(Figure):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    
    def subplots(self, *args, **kwargs):
        axes = super().subplots(*args, **kwargs)
        axes = [CustomAxes(fig=self, *args, **kwargs) for ax in axes.flatten()]
        return axes

fig, axes = CustomFigure().subplots(nrows=2, ncols=2)
for i, ax in enumerate(axes, start=1):
    ax.create_plot(i=i)
fig.tight_layout()

fig

Traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[60], line 23
     20         axes = [CustomAxes(fig=self, *args, **kwargs) for ax in axes.flatten()]
     21         return axes
---> 23 fig, axes = CustomFigure().subplots(nrows=2, ncols=2)
     24 for i, ax in enumerate(axes, start=1):
     25     ax.create_plot(i=i)

Cell In[60], line 20
     18 def subplots(self, *args, **kwargs):
     19     axes = super().subplots(*args, **kwargs)
---> 20     axes = [CustomAxes(fig=self, *args, **kwargs) for ax in axes.flatten()]
     21     return axes

Cell In[60], line 20
     18 def subplots(self, *args, **kwargs):
     19     axes = super().subplots(*args, **kwargs)
---> 20     axes = [CustomAxes(fig=self, *args, **kwargs) for ax in axes.flatten()]
     21     return axes

Cell In[60], line 7
      6 def __init__(self, fig, *args, **kwargs):
----> 7     super().__init__(fig, *args, **kwargs)

File ~/repos/test/venv/lib/python3.11/site-packages/matplotlib/axes/_base.py:656, in _AxesBase.__init__(self, fig, facecolor, frameon, sharex, sharey, label, xscale, yscale, box_aspect, forward_navigation_events, *args, **kwargs)
    654 else:
    655     self._position = self._originalPosition = mtransforms.Bbox.unit()
--> 656     subplotspec = SubplotSpec._from_subplot_args(fig, args)
    657 if self._position.width < 0 or self._position.height < 0:
    658     raise ValueError('Width and height specified must be non-negative')

File ~/repos/test/venv/lib/python3.11/site-packages/matplotlib/gridspec.py:576, in SubplotSpec._from_subplot_args(figure, args)
    574     rows, cols, num = args
    575 else:
--> 576     raise _api.nargs_error("subplot", takes="1 or 3", given=len(args))
    578 gs = GridSpec._check_gridspec_exists(figure, rows, cols)
    579 if gs is None:

TypeError: subplot() takes 1 or 3 positional arguments but 0 were given

Working code without inheritance:

from matplotlib.figure import Figure


class CustomAxes():

    def __init__(self, ax):
        self.ax = ax

    def create_plot(self, i):
        self.ax.plot([1, 2, 3], [1, 2, 3])
        self.ax.set_title(f'Title {i}')


class CustomFigure(Figure):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def subplots(self, *args, **kwargs):
        axes = super().subplots(*args, **kwargs)
        axes = [CustomAxes(ax) for ax in axes.flatten()]
        return self, axes


fig, axes = CustomFigure().subplots(nrows=2, ncols=2)
for i, ax in enumerate(axes, start=1):
    ax.create_plot(i=i)
fig.tight_layout()

fig

Solution

  • I think the error is because you are not passing any position information (i.e. args) to Axes.__init__ when you call it via super. However, you can do this more simply without subclassing Figure, since subplots lets you specify an Axes subclass:

    import matplotlib.pyplot as plt
    from matplotlib.axes import Axes
    
    class CustomAxes(Axes):
    
        def __init__(self, fig, *args, **kwargs):
            super().__init__(fig, *args, **kwargs)
    
        def create_plot(self, i):
            self.plot([1, 2, 3], [1, 2, 3])
            self.set_title(f'Title {i}')
    
    
    fig, ax_arr = plt.subplots(nrows=2, ncols=2, subplot_kw={'axes_class': CustomAxes})
    for i, ax in enumerate(ax_arr.flat, start=1):
        ax.create_plot(i=i)
    fig.tight_layout()
    
    plt.show()
    

    enter image description here