Tags: python, matplotlib, pandas

Reset color cycle in Matplotlib


Say I have data about 3 trading strategies, each with and without transaction costs. I want to plot, on the same axes, the time series of each of the 6 variants (3 strategies * 2 trading costs). I would like the "with transaction cost" lines to be plotted with alpha=1 and linewidth=1 while I want the "no transaction costs" to be plotted with alpha=0.25 and linewidth=5. But I would like the color to be the same for both versions of each strategy.

I would like something along the lines of:

fig, ax = plt.subplots(1, 1, figsize=(10, 10))

for c in with_transaction_frame.columns:
    ax.plot(with_transaction_frame[c], label=c, alpha=1, linewidth=1)

# **** SOME MAGIC GOES HERE TO RESET THE COLOR CYCLE ****

for c in no_transaction_frame.columns:
    ax.plot(no_transaction_frame[c], label=c, alpha=0.25, linewidth=5)

ax.legend()

What is the appropriate code to put on the indicated line to reset the color cycle so it is "back to the start" when the second loop is invoked?


Solution

  • You can reset the color cycle to the default with Axes.set_color_cycle (Matplotlib < 1.5) or Axes.set_prop_cycle (Matplotlib >= 1.5). Looking at the older code, there is a helper that does the actual work:

    def set_color_cycle(self, clist=None):
        if clist is None:
            clist = rcParams['axes.color_cycle']
        self.color_cycle = itertools.cycle(clist)
    

    And a method on the Axes which uses it:

    def set_color_cycle(self, clist):
        """
        Set the color cycle for any future plot commands on this Axes.
    
        *clist* is a list of mpl color specifiers.
        """
        self._get_lines.set_color_cycle(clist)
        self._get_patches_for_fill.set_color_cycle(clist)
    

    This means you can call set_color_cycle with None as the only argument, and the cycle is replaced with the default found in rcParams['axes.color_cycle']. On Matplotlib 1.5 and later, set_prop_cycle(None) does the same using rcParams['axes.prop_cycle'].

    I tried this with the following code and got the expected result:

    import matplotlib.pyplot as plt
    import numpy as np
    
    for i in range(3):
        plt.plot(np.arange(10) + i)
    
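    # Reset the color cycle to the default. Use only the call that
    # matches your Matplotlib version; each exists on one side of 1.5 only.
    # Matplotlib < 1.5: plt.gca().set_color_cycle(None)
    # Matplotlib >= 1.5:
    plt.gca().set_prop_cycle(None)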
    
    for i in range(3):
        plt.plot(np.arange(10, 1, -1) + i)
    
    plt.show()
    

    [Figure: code output, showing the color cycle reset]
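
Applied to the two frames from the question, the reset might look like the sketch below. It assumes Matplotlib 1.5 or later (so set_prop_cycle is available). The data and the strategy names are made up for illustration, standing in for with_transaction_frame and no_transaction_frame. The second loop is left unlabelled, which is a choice made here so that the legend lists each strategy once.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the two frames in the question.
rng = np.random.default_rng(0)
index = pd.date_range("2020-01-01", periods=100)
names = ["strategy_a", "strategy_b", "strategy_c"]  # made-up column names
gross = rng.normal(0.001, 0.01, size=(100, 3)).cumsum(axis=0)
costs = 0.0005 * np.arange(100)[:, None]  # made-up cumulative cost drag
no_transaction_frame = pd.DataFrame(gross, index=index, columns=names)
with_transaction_frame = pd.DataFrame(gross - costs, index=index, columns=names)

fig, ax = plt.subplots(1, 1, figsize=(10, 10))

# First pass: thin, opaque lines for the "with transaction cost" variants.
thin_lines = [
    ax.plot(with_transaction_frame[c], label=c, alpha=1, linewidth=1)[0]
    for c in with_transaction_frame.columns
]

# Restart the property cycle so the next plot call begins at the first color.
ax.set_prop_cycle(None)

# Second pass: thick, translucent lines for the "no transaction cost" variants.
thick_lines = [
    ax.plot(no_transaction_frame[c], alpha=0.25, linewidth=5)[0]
    for c in no_transaction_frame.columns
]

ax.legend()
# Call plt.show() or fig.savefig(...) here to view the result.
```

This relies on both frames having their columns in the same order. If that is not guaranteed, reading each color from the first pass with Line2D.get_color() and passing it explicitly as color= in the second pass avoids depending on the cycle at all.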