Tags: python, matplotlib, visualization, x-axis

How to add additional x-axes but with different scale and color


I have the following plot:

[Image of the current plot not preserved]

Each of the three lines belongs to a different x-axis scale. Specifically, the x-axis of the Fully Connected line should range between 0.001 and 0.02, the x-axis of the ε-NN line between 0.05 and 1.95, and the x-axis of the kNN line between 2 and 40. I want to remove the current x-axis and instead have three x-axes stacked one below the other, each with its own scale and color.

Here is my code:

# Intended x ranges for the three curves. They are not used for plotting yet:
# every curve below is drawn against the shared integer index ``x``.
x_full = np.linspace(0.001, 0.02, 20)
x_enn = np.linspace(0.05, 1.95, 20)
x_knn = np.linspace(2, 40, 20)
x = np.arange(len(x_full))

fig, ax = plt.subplots(1, 2, figsize=(13.2, 4))

# One entry per curve: (x positions, mean, std, legend label).
# The epsilon-NN series is sliced with [0.1:] (presumably a label-based
# pandas slice -- confirm) and drawn against x[1:].
curves = [
    (
        x,
        two_moons_acc_mean['full'],
        two_moons_acc_std['full'],
        'Fully Connected',
    ),
    (
        x[1:],
        two_moons_acc_mean['enn'][0.1:],
        two_moons_acc_std['enn'][0.1:],
        r'$\epsilon$-N',
    ),
    (
        x,
        two_moons_acc_mean['knn'],
        two_moons_acc_std['knn'],
        r'$k$NN',
    ),
]

for panel in ax:
    # Both panels use the same colour list, so each method gets the same
    # line colour on the left and on the right.
    panel.set_prop_cycle(color=color_list)
    # Mean curves first, then the translucent mean +/- std bands.
    for xs, mean, _std, label in curves:
        panel.plot(xs, mean, label=label)
    for xs, mean, std, _label in curves:
        panel.fill_between(xs, mean - std, mean + std, alpha=0.2)

# The right-hand panel repeats the left one on log-log axes.
ax[1].set_xscale('log')
ax[1].set_yscale('log')

for panel in ax:
    panel.legend(loc='lower left', ncol=3, frameon=False)

Solution

  • Use three separate Axes, one for each line you need to plot, all sharing the same y-axis but each with its own x-axis.
    Create the first one with:

    # Figure plus the first Axes, which carries the red "full" curve.
    fig, ax_full = plt.subplots()
    full = ax_full.plot(x_full, y_full, color='red', label='full')
    

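    Then create each additional Axes as a twin of the first with twiny(), which shares the y-axis but gives the new Axes its own x-axis (shown here for ε-NN; repeat the same steps for kNN):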

    # Twin Axes on top of ax_full: same y-axis, independent x-axis.
    ax_enn = ax_full.twiny()
    

    And plot each line on the respective axis:

    # Blue "enn" curve on the twin Axes, so it uses that Axes' own x scale.
    enn = ax_enn.plot(x_enn, y_enn, color='blue', label='enn')
    

    twiny() places the new x-axis at the top, so move it to the bottom and offset its spine below the plot area with:

    # Push the bottom spine 0.15 Axes heights below the plot area, then put
    # this Axes' x-label and ticks on that bottom spine.
    ax_enn.spines['bottom'].set_position(('axes', -0.15))
    ax_enn.xaxis.set_label_position('bottom')
    ax_enn.xaxis.set_ticks_position('bottom')
    

    And finally customize the colors:

    # Colour every x-axis element of ax_enn blue: ticks and tick labels,
    # the spine itself, and the x-label (once one is set).
    ax_enn.tick_params(axis='x', colors='blue')
    ax_enn.spines['bottom'].set_color('blue')
    ax_enn.xaxis.label.set_color('blue')
    

    Complete Code

    import numpy as np
    import matplotlib.pyplot as plt
    
    
    x_full = np.linspace(0.001, 0.02, 20)
    x_enn = np.linspace(0.05, 1.95, 20)
    x_knn = np.linspace(2, 40, 20)

    # Placeholder y-data; a seeded generator makes the example reproducible.
    rng = np.random.default_rng(0)
    y_full = rng.random(len(x_full))
    y_enn = rng.random(len(x_enn))
    y_knn = rng.random(len(x_knn))


    # First Axes: the red "full" x-axis stays in the normal bottom position.
    fig, ax_full = plt.subplots()

    full = ax_full.plot(x_full, y_full, color='red', label='full')
    ax_full.spines['bottom'].set_color('red')
    ax_full.tick_params(axis='x', colors='red')
    # No x-label is set in this example; the label colour takes effect as
    # soon as one is added with set_xlabel().
    ax_full.xaxis.label.set_color('red')


    # Second Axes: twiny() shares the y-axis and adds an independent x-axis.
    # Its ticks/label are moved to the bottom and its spine is offset 0.15
    # Axes heights below the plot so it sits under the red axis.
    ax_enn = ax_full.twiny()
    enn = ax_enn.plot(x_enn, y_enn, color='blue', label='enn')
    ax_enn.xaxis.set_ticks_position('bottom')
    ax_enn.xaxis.set_label_position('bottom')
    ax_enn.spines['bottom'].set_position(('axes', -0.15))
    ax_enn.spines['bottom'].set_color('blue')
    ax_enn.tick_params(axis='x', colors='blue')
    ax_enn.xaxis.label.set_color('blue')


    # Third Axes: same recipe in green, offset 0.3 Axes heights below.
    ax_knn = ax_full.twiny()
    knn = ax_knn.plot(x_knn, y_knn, color='green', label='knn')
    ax_knn.xaxis.set_ticks_position('bottom')
    ax_knn.xaxis.set_label_position('bottom')
    ax_knn.spines['bottom'].set_position(('axes', -0.3))
    ax_knn.spines['bottom'].set_color('green')
    ax_knn.tick_params(axis='x', colors='green')
    ax_knn.xaxis.label.set_color('green')


    # The twin Axes are drawn after ax_full, so their curves would be painted
    # over a legend attached to ax_full. Attach the combined legend to the
    # last-created (top-most) Axes so it stays visible.
    lines = full + enn + knn
    labels = [line.get_label() for line in lines]
    ax_knn.legend(lines, labels)

    plt.tight_layout()

    plt.show()
    

    [Image of the resulting plot not preserved]