Tags: python, matplotlib, subplot, twinx

Dual x-axis with same data, different scale


I'd like to plot some data in Python using two different x-axes. For ease of explanation, say I want to plot light absorption data, i.e. absorbance vs. wavelength (nm) or energy (eV). I want a plot where the bottom axis shows the wavelength in nm and the top axis shows the energy in eV. The two are not linearly related: since E = hc/λ, the energy is inversely proportional to the wavelength (as you can see in my MWE below).

My full MWE:

import numpy as np
import matplotlib.pyplot as plt 
import scipy.constants as constants

# Converting wavelength (nm) to energy (eV)
def WLtoE(wl):
    # E = h*c/wl            
    h = constants.h         # Planck constant
    c = constants.c         # Speed of light
    J_eV = constants.e      # Joule-electronvolt relationship
    
    wl_nm = wl * 10**(-9)   # convert wl from nm to m
    E_J = (h*c) / wl_nm     # energy in units of J
    E_eV = E_J / J_eV       # energy in units of eV
    
    return E_eV

x = np.arange(200,2001,5)
x_mod = WLtoE(x)
y = 2*x + 3

fig, ax1 = plt.subplots()

ax2 = ax1.twiny()
ax1.plot(x, y, color='red')
ax2.plot(x_mod, y, color = 'green')

ax1.set_xlabel('Wavelength (nm)', fontsize = 'large', color='red')
ax1.set_ylabel('Absorbance (a.u.)', fontsize = 'large')
ax1.tick_params(axis='x', colors='red')

ax2.set_xlabel('Energy (eV)', fontsize='large', color='green')
ax2.tick_params(axis='x', colors='green')
ax2.spines['top'].set_color('green')
ax2.spines['bottom'].set_color('red')

plt.tight_layout()
plt.show()

This yields the following figure: [figure: output of the MWE above]

Now this is close to what I want, but I'd like to solve the following two issues:

  1. One of the axes needs to be reversed: a long wavelength corresponds to a low energy, but the figure does not show this. I tried using x_mod = WLtoE(x)[::-1], for example, but this does not solve the issue.
  2. Since the two quantities are not linearly related, I'd like the top and bottom axes to "match". For example, right now 1000 nm lines up with roughly 3 eV, but in reality 1000 nm corresponds to 1.24 eV. So one of the axes (preferably the bottom, wavelength axis) needs to be condensed/expanded to match the correct energy values at the top. In other words, I'd like the red and green curves to coincide.
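For reference, here is a quick numerical check of the values quoted above. This sketch assumes scipy.constants (as in the MWE) and uses a hypothetical vectorised helper, wl_to_e, that is equivalent to WLtoE. It evaluates E = hc/λ for a few wavelengths, confirming that 1000 nm is about 1.24 eV and that energy falls as wavelength grows, which is why one axis has to run in the opposite direction.

```python
import numpy as np
import scipy.constants as constants

# h*c/e expressed in eV*nm (about 1239.84), so E[eV] = HC_EV_NM / wavelength[nm]
HC_EV_NM = constants.h * constants.c / constants.e * 1e9

def wl_to_e(wl_nm):
    """Hypothetical helper: convert wavelength in nm to photon energy in eV."""
    return HC_EV_NM / np.asarray(wl_nm, dtype=float)

wavelengths = np.array([200.0, 500.0, 1000.0, 2000.0])
energies = wl_to_e(wavelengths)
for wl, e in zip(wavelengths, energies):
    print(f"{wl:6.0f} nm -> {e:.2f} eV")
```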

I appreciate any and all tips & tricks to help me make a nice plot! Thanks in advance.

EDIT: DeX97's answer solved my problem perfectly, although I made some minor changes to the way I plot things, as you can see below. Defining the conversion functions the way DeX97 did worked perfectly.

Edited code for plotting

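# Assumes x, y, WLtoE, EtoWL and preventDivisionByZero are defined as in the answer below
from matplotlib.ticker import FormatStrFormatter

fig, ax1 = plt.subplots()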

ax1.plot(WLtoE(x), y)
ax1.set_xlabel('Energy (eV)', fontsize = 'large')
ax1.set_ylabel('Absorbance (a.u.)', fontsize = 'large')

# Create the second x-axis on which the wavelength in nm will be displayed
ax2 = ax1.secondary_xaxis('top', functions=(EtoWL, WLtoE))
ax2.set_xlabel('Wavelength (nm)', fontsize='large')
# Invert the wavelength axis
ax2.invert_xaxis()

# Make own array of wavelength ticks, so they are round numbers
# The values are not linearly spaced, but that is the idea.
wl_ticks = np.asarray([200, 250, 300, 350, 400, 500, 600, 750, 1000, 2000])

# Set the ticks for ax2 (wl)
ax2.set_xticks(wl_ticks)

# Make the values on ax2 (wavelength) integer values
ax2.xaxis.set_major_formatter(FormatStrFormatter('%i'))

plt.tight_layout()
plt.show()
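The snippet above relies on definitions from the answer below. A self-contained sketch of the same layout (energy on the bottom axis, round wavelength ticks on the top axis) might look like this. The helper hc_over is a hypothetical name: it can serve as both conversion functions because E = hc/λ and λ = hc/E have the same form. The sketch leaves out the invert_xaxis call; because the conversion is a decreasing function, the top axis already shows wavelength decreasing from left to right as the energy increases.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import scipy.constants as constants

# h*c/e in eV*nm (about 1239.84)
HC_EV_NM = constants.h * constants.c / constants.e * 1e9

def hc_over(value):
    """Hypothetical helper: maps nm -> eV and eV -> nm (the formula is its own inverse).

    Near-zero inputs are replaced by machine epsilon so that matplotlib's
    evaluation of the transform never divides by zero."""
    value = np.asarray(value, dtype=float)
    tiny = np.finfo(float).eps
    return HC_EV_NM / np.where(np.abs(value) < tiny, tiny, value)

x = np.arange(200, 2001, 5)   # wavelengths in nm
y = 2 * x + 3                 # dummy absorbance

fig, ax1 = plt.subplots()
ax1.plot(hc_over(x), y)       # plot the data once, against energy
ax1.set_xlabel("Energy (eV)")
ax1.set_ylabel("Absorbance (a.u.)")

# Top axis in wavelength: forward and backward transforms are the same reciprocal
ax2 = ax1.secondary_xaxis("top", functions=(hc_over, hc_over))
ax2.set_xlabel("Wavelength (nm)")

# Round wavelength ticks, placed directly in wavelength units on the top axis
wl_ticks = np.array([200, 250, 300, 350, 400, 500, 600, 750, 1000, 2000])
ax2.set_xticks(wl_ticks)
ax2.xaxis.set_major_formatter(FormatStrFormatter("%d"))

fig.tight_layout()
fig.canvas.draw()
```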

Solution

  • In your code example, you plot the same data twice (albeit transformed using E=h*c/wl). I think it is sufficient to plot the data only once and create two x-axes: one displaying the wavelength in nm and one displaying the corresponding energy in eV.

    Consider the adjusted code below:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.ticker import FormatStrFormatter
    import scipy.constants as constants
    from sys import float_info
    
    # Function to prevent zero values in an array
    def preventDivisionByZero(some_array):
        corrected_array = some_array.copy()
        for i, entry in enumerate(some_array):
            # If element is zero, set to some small value
            if abs(entry) < float_info.epsilon:
                corrected_array[i] = float_info.epsilon
        
        return corrected_array
    
    
    # Converting wavelength (nm) to energy (eV)
    def WLtoE(wl):
        # Prevent division by zero error
        wl = preventDivisionByZero(wl)
    
        # E = h*c/wl            
        h = constants.h         # Planck constant
        c = constants.c         # Speed of light
        J_eV = constants.e      # Joule-electronvolt relationship
        
        wl_nm = wl * 10**(-9)   # convert wl from nm to m
        E_J = (h*c) / wl_nm     # energy in units of J
        E_eV = E_J / J_eV       # energy in units of eV
        
        return E_eV
        
    
    # Converting energy (eV) to wavelength (nm)
    def EtoWL(E):
        # Prevent division by zero error
        E = preventDivisionByZero(E)
        
        # Calculates the wavelength in nm
        return constants.h * constants.c / (constants.e * E) * 10**9
    
    
    x = np.arange(200,2001,5)
    y = 2*x + 3
    
    fig, ax1 = plt.subplots()
    
    ax1.plot(x, y, color='black')
    
    ax1.set_xlabel('Wavelength (nm)', fontsize = 'large')
    ax1.set_ylabel('Absorbance (a.u.)', fontsize = 'large')
    
    # Invert the wavelength axis
    ax1.invert_xaxis()
    
    # Create the second x-axis on which the energy in eV will be displayed
    ax2 = ax1.secondary_xaxis('top', functions=(WLtoE, EtoWL))
    ax2.set_xlabel('Energy (eV)', fontsize='large')
    
    # Get ticks from ax1 (wavelengths)
    wl_ticks = ax1.get_xticks()
    wl_ticks = preventDivisionByZero(wl_ticks)
    
    # Based on the ticks from ax1 (wavelengths), calculate the corresponding
    # energies in eV
    E_ticks = WLtoE(wl_ticks)
    
    # Set the ticks for ax2 (Energy)
    ax2.set_xticks(E_ticks)
    
    # Allow for two decimal places on ax2 (Energy)
    ax2.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    
    plt.tight_layout()
    plt.show()
    

    First of all, I define the preventDivisionByZero utility function. It takes an array as input, finds the entries whose absolute value is smaller than sys.float_info.epsilon (i.e. entries that are effectively zero), and replaces them with sys.float_info.epsilon, a small non-zero number. This function is used in a few places to prevent division by zero; I will come back to why this is important later.
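The loop in preventDivisionByZero can also be written as a single NumPy expression. The variant below, prevent_division_by_zero, is an illustrative alternative rather than the answer's code. It first converts the input to a float array, which matters because assigning float_info.epsilon into an integer array (for example one built with np.arange(200, 2001, 5)) would silently truncate the value back to 0.

```python
import numpy as np
from sys import float_info

def prevent_division_by_zero(some_array):
    """Return a float copy of some_array with near-zero entries replaced by epsilon."""
    values = np.asarray(some_array, dtype=float)
    # np.where builds a new array, so the caller's data is never modified
    return np.where(np.abs(values) < float_info.epsilon, float_info.epsilon, values)

print(prevent_division_by_zero(np.array([0, 5, 10])))   # integer input
print(prevent_division_by_zero([0.0, -1e-20, 2.5]))     # tiny negative value
```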

    After this function, your WLtoE function is defined. Note that I added a call to preventDivisionByZero at the top of your function. In addition, I defined an EtoWL function, which performs the inverse conversion (energy in eV to wavelength in nm).
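Because E = hc/λ and λ = hc/E have exactly the same form, EtoWL is the same reciprocal as WLtoE, and applying one after the other should return the original values up to floating-point rounding. secondary_xaxis relies on the two functions being inverses of each other, so a round-trip check like the sketch below is a cheap way to catch a unit slip such as a missing factor of 10**9. The sketch re-derives both conversions from scipy.constants and, for brevity, omits the zero guard.

```python
import numpy as np
import scipy.constants as constants

def WLtoE(wl):
    # wavelength in nm -> energy in eV (no zero guard in this sketch)
    return constants.h * constants.c / (np.asarray(wl, dtype=float) * 1e-9) / constants.e

def EtoWL(E):
    # energy in eV -> wavelength in nm (no zero guard in this sketch)
    return constants.h * constants.c / (constants.e * np.asarray(E, dtype=float)) * 1e9

wl = np.arange(200, 2001, 5)
round_trip = EtoWL(WLtoE(wl))
print(np.max(np.abs(round_trip - wl)))  # largest round-trip error in nm
```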

    Then the dummy data are generated and plotted on ax1, whose x-axis shows the wavelength. After setting some labels, the x-axis of ax1 is inverted (as requested in your original post), so that the wavelength increases from right to left.
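Inverting changes only the direction of the axis, not the data: after ax1.invert_xaxis() the x-limits are stored in descending order, so long wavelengths appear on the left. A minimal check, using the non-interactive Agg backend:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt

x = np.arange(200, 2001, 5)
fig, ax1 = plt.subplots()
ax1.plot(x, 2 * x + 3)

left_before, right_before = ax1.get_xlim()
ax1.invert_xaxis()                      # swap the order of the x-limits
left_after, right_after = ax1.get_xlim()
print(ax1.xaxis_inverted(), (left_before, right_before), (left_after, right_after))
```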

    Now, we create the second axis for the energy using ax2 = ax1.secondary_xaxis('top', functions=(WLtoE, EtoWL)). The first argument indicates that the axis should be placed at the top of the figure. The functions keyword argument takes a tuple of two functions: the first is the forward transform (here wavelength to energy), the second the inverse transform (energy to wavelength). See Axes.secondary_xaxis in the matplotlib documentation for more information. Note that matplotlib calls these two functions itself whenever it needs to convert between the axes, and the values it passes can be zero, so both functions must handle that case. Hence the preventDivisionByZero function! After creating the second axis, its label is set.
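To see the forward/backward pair at work in isolation, the sketch below attaches a secondary axis to an inverted wavelength axis and, after an off-screen draw, compares the secondary axis limits with the forward transform of the parent limits. It uses a single zero-safe reciprocal, hypothetically named hc_over, for both directions, since the wavelength/energy conversion is its own inverse. The comparison sorts the limits so that it does not depend on the order in which they are stored.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import scipy.constants as constants

HC_EV_NM = constants.h * constants.c / constants.e * 1e9  # about 1239.84 eV*nm

def hc_over(value):
    """Hypothetical zero-safe reciprocal: nm -> eV and eV -> nm."""
    value = np.asarray(value, dtype=float)
    tiny = np.finfo(float).eps
    return HC_EV_NM / np.where(np.abs(value) < tiny, tiny, value)

x = np.arange(200, 2001, 5)
fig, ax1 = plt.subplots()
ax1.plot(x, 2 * x + 3)
ax1.invert_xaxis()

# forward: wavelength -> energy, backward: energy -> wavelength
ax2 = ax1.secondary_xaxis("top", functions=(hc_over, hc_over))
fig.canvas.draw()  # draw once so the secondary axis updates its limits from the parent

parent_lims = np.array(ax1.get_xlim())
print(sorted(ax2.get_xlim()), sorted(hc_over(parent_lims)))
```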

    Now we have two x-axes, but the ticks on the two axes are at different locations. To 'solve' this, we store the tick locations of the wavelength x-axis in wl_ticks. After ensuring there are no zero elements using the preventDivisionByZero function, we calculate the corresponding energy values with the WLtoE function and store them in E_ticks. Finally, we set the tick locations of the second x-axis to these values using ax2.set_xticks(E_ticks), so that every energy tick sits directly above its wavelength tick.

    To allow for two decimal places on the second x-axis (energy), we use ax2.xaxis.set_major_formatter(FormatStrFormatter('%.2f')). Of course, you can choose the desired number of decimal places yourself.
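The tick matching comes down to mapping each wavelength tick through WLtoE and formatting the result. The sketch below does this for a few round wavelengths (chosen for illustration, not the exact ticks matplotlib picks) and calls FormatStrFormatter('%.2f') directly; each label is the energy that sits directly above the corresponding wavelength tick.

```python
import numpy as np
import scipy.constants as constants
from matplotlib.ticker import FormatStrFormatter

def WLtoE(wl):
    # wavelength in nm -> photon energy in eV
    return constants.h * constants.c / (np.asarray(wl, dtype=float) * 1e-9) / constants.e

wl_ticks = np.array([250, 500, 1000, 2000])   # illustrative wavelength ticks (nm)
E_ticks = WLtoE(wl_ticks)                     # positions of the matching energy ticks

formatter = FormatStrFormatter("%.2f")        # two decimal places, as in the answer
labels = [formatter(E) for E in E_ticks]
for wl, label in zip(wl_ticks, labels):
    print(f"{wl} nm <-> {label} eV")
```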

    The code given above produces the following graph: [figure: output of the code above]