I'm trying to mimic the `plt.subplots()` behavior, but with custom classes. Rather than returning `Axes` from `subplots()`, I would like to return `CustomAxes`. I've looked at the source code and don't understand why I am getting the traceback error below.
I'm able to accomplish what I want without inheriting from `Axes`, but long term I think I would like to inherit from `Axes`. If you think this is ridiculous and there's a better way, let me know!
Code:
from matplotlib.figure import Figure
from matplotlib.axes import Axes
# Axes subclass that adds a convenience method for drawing a demo line plot.
class CustomAxes(Axes):
    def __init__(self, fig, *args, **kwargs):
        # Forwards everything to Axes.__init__. Per the traceback below, when
        # *args is empty Axes.__init__ falls through to
        # SubplotSpec._from_subplot_args(fig, args), which raises TypeError.
        super().__init__(fig, *args, **kwargs)

    def create_plot(self, i):
        """Plot y = x for x in 1..3 and set the title to 'Title {i}'."""
        self.plot([1, 2, 3], [1, 2, 3])
        self.set_title(f'Title {i}')
# Figure subclass whose subplots() is meant to return CustomAxes instead of Axes.
# NOTE(review): as written this fails; the comments below point out why.
class CustomFigure(Figure):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def subplots(self, *args, **kwargs):
        # Creates the regular Axes grid first; these Axes are discarded below.
        axes = super().subplots(*args, **kwargs)
        # Builds brand-new CustomAxes rather than converting the existing ones.
        # When called as subplots(nrows=2, ncols=2), *args is empty, so no
        # position reaches Axes.__init__ (the TypeError in the traceback), and
        # the nrows/ncols keywords are forwarded to CustomAxes as well.
        # The loop variable `ax` is never used.
        axes = [CustomAxes(fig=self, *args, **kwargs) for ax in axes.flatten()]
        # Returns only the list, although the caller unpacks `fig, axes`.
        return axes
# Build a 2x2 grid and draw on each custom axes. This raises TypeError inside
# CustomFigure.subplots, as shown in the traceback below.
fig, axes = CustomFigure().subplots(nrows=2, ncols=2)
for i, ax in enumerate(axes, start=1):
    ax.create_plot(i=i)
fig.tight_layout()
fig  # last expression of the (IPython) cell, so the figure gets displayed
Traceback:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[60], line 23
20 axes = [CustomAxes(fig=self, *args, **kwargs) for ax in axes.flatten()]
21 return axes
---> 23 fig, axes = CustomFigure().subplots(nrows=2, ncols=2)
24 for i, ax in enumerate(axes, start=1):
25 ax.create_plot(i=i)
Cell In[60], line 20
18 def subplots(self, *args, **kwargs):
19 axes = super().subplots(*args, **kwargs)
---> 20 axes = [CustomAxes(fig=self, *args, **kwargs) for ax in axes.flatten()]
21 return axes
Cell In[60], line 20
18 def subplots(self, *args, **kwargs):
19 axes = super().subplots(*args, **kwargs)
---> 20 axes = [CustomAxes(fig=self, *args, **kwargs) for ax in axes.flatten()]
21 return axes
Cell In[60], line 7
6 def __init__(self, fig, *args, **kwargs):
----> 7 super().__init__(fig, *args, **kwargs)
File ~/repos/test/venv/lib/python3.11/site-packages/matplotlib/axes/_base.py:656, in _AxesBase.__init__(self, fig, facecolor, frameon, sharex, sharey, label, xscale, yscale, box_aspect, forward_navigation_events, *args, **kwargs)
654 else:
655 self._position = self._originalPosition = mtransforms.Bbox.unit()
--> 656 subplotspec = SubplotSpec._from_subplot_args(fig, args)
657 if self._position.width < 0 or self._position.height < 0:
658 raise ValueError('Width and height specified must be non-negative')
File ~/repos/test/venv/lib/python3.11/site-packages/matplotlib/gridspec.py:576, in SubplotSpec._from_subplot_args(figure, args)
574 rows, cols, num = args
575 else:
--> 576 raise _api.nargs_error("subplot", takes="1 or 3", given=len(args))
578 gs = GridSpec._check_gridspec_exists(figure, rows, cols)
579 if gs is None:
TypeError: subplot() takes 1 or 3 positional arguments but 0 were given
Working code without inheritance:
from matplotlib.figure import Figure
class CustomAxes:
    """Composition wrapper: holds an Axes-like object and draws on it."""

    def __init__(self, ax):
        # The wrapped object stays publicly reachable as ``.ax``.
        self.ax = ax

    def create_plot(self, i):
        """Draw y = x over 1..3 on the wrapped axes, titled 'Title {i}'."""
        target = self.ax
        xs = [1, 2, 3]
        ys = [1, 2, 3]
        target.plot(xs, ys)
        target.set_title(f'Title {i}')
class CustomFigure(Figure):
    """Figure whose subplots() wraps every created Axes in a CustomAxes.

    The constructor is inherited unchanged from Figure. Unlike
    Figure.subplots, subplots() here returns ``(figure, wrappers)`` so it can
    be unpacked the same way as plt.subplots().
    """

    def subplots(self, *args, **kwargs):
        """Create a subplot grid and wrap each Axes in a CustomAxes.

        All arguments are forwarded unchanged to Figure.subplots.

        Returns:
            A tuple ``(self, wrappers)`` where ``wrappers`` is a list of
            CustomAxes in the flattened (row-major) order of the grid.
        """
        axes = super().subplots(*args, **kwargs)
        # With the default squeeze=True a 1x1 grid yields a single Axes rather
        # than an array, and an Axes has no .flatten()/.flat. Treat that case
        # as a one-element sequence so subplots() with no arguments works too.
        flat_axes = getattr(axes, "flat", [axes])
        return self, [CustomAxes(ax) for ax in flat_axes]
# Create a 2x2 grid; CustomFigure.subplots hands back (figure, wrappers).
fig, axes = CustomFigure().subplots(nrows=2, ncols=2)
for i, ax in enumerate(axes, 1):
    # Number subplots from 1 so the titles read "Title 1" .. "Title 4".
    ax.create_plot(i=i)
fig.tight_layout()
fig  # last expression of a notebook/IPython cell: the figure is displayed
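The error occurs because no position information (i.e. `args`) reaches `Axes.__init__` when you call it via `super()`. You called `subplots(nrows=2, ncols=2)` with keyword arguments only, so `*args` is empty. `Axes` then has no subplot specification to place itself with, which is exactly the "takes 1 or 3 positional arguments but 0 were given" error in the traceback. However, you can do this more simply without subclassing `Figure`, since `subplots` lets you specify an `Axes` subclass through its `subplot_kw` argument: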
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
class CustomAxes(Axes):
    """Axes subclass adding a convenience method for a demo line plot.

    No ``__init__`` override is needed. Axes' constructor is inherited
    unchanged, so ``subplots(..., subplot_kw={'axes_class': CustomAxes})``
    can build it exactly like a plain Axes.
    """

    def create_plot(self, i):
        """Plot y = x for x in 1..3 and set the title to ``Title {i}``."""
        self.plot([1, 2, 3], [1, 2, 3])
        self.set_title(f'Title {i}')
# Ask subplots() to build every subplot as a CustomAxes instead of a plain Axes.
fig, ax_arr = plt.subplots(
    nrows=2,
    ncols=2,
    subplot_kw=dict(axes_class=CustomAxes),
)
# ax_arr is a 2x2 array; .flat walks it row by row, numbering from 1.
for i, ax in enumerate(ax_arr.flat, 1):
    ax.create_plot(i=i)
fig.tight_layout()
plt.show()