I want to plot suspension travel against distance for a train. My dataset contains this information, along with whether the train is on a straight track or in a curve. I want to mimic the following image, with some additional information added (bridge and tunnel locations, for example). The problem is that my attempt takes almost four minutes to run.
Plot I want to copy
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider

# Plot of the line I want to add
def transition_line(xmin, xmax):
    for i in range((df_irv['Distance'] - xmin).abs().idxmin(), (df_irv['Distance'] - xmax).abs().idxmin()):
        plt.plot([df_irv['Distance'][i], df_irv['Distance'][i+1]], [max(df_irv['SuspTravel']) + 0.5]*2, color='red' if df_irv['Element'][i] == 'CURVA' else 'blue', linewidth=10, alpha=0.5)
# Function to plot the data with adjustable x-axis limits
def plot_graph(xmin, xmax, sensors):
    plt.figure(figsize=(10, 5))
    plt.plot(df_irv['Distance'], df_irv[sensors], label='Suspension Sensor')
    plt.xlim(xmin, xmax)
    plt.xlabel('Distance (Km)')
    plt.ylabel('Suspension Sensor')
    plt.title('Suspension Sensor vs Distance')
    plt.legend()
    plt.grid(True)
    transition_line(xmin, xmax)
    plt.show()
# Create sliders for x-axis limits
xmin_slider = IntSlider(value=0, min=0, max=df_irv['Distance'].max(), step=1, description='X min')
xmax_slider = IntSlider(value=20, min=0, max=df_irv['Distance'].max(), step=1, description='X max')
# Interactive plot
interact(plot_graph, xmin=xmin_slider, xmax=xmax_slider, sensors = ['SuspTravel', 'Roll', 'Bounce'])
Image produced by my attempt
Calling plt.plot() many times in a loop can be slow. (Also, max(df_irv['SuspTravel']) is recalculated at every step of the loop; computing it once before the loop starts avoids that.)
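To make the cost visible, here is a minimal sketch of the loop approach on made-up data (the DataFrame `df` below is a hypothetical stand-in for `df_irv`, which the question does not show). It hoists the maximum out of the loop and then counts the artists: every `plot` call adds its own `Line2D` object that matplotlib has to manage and draw separately.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Made-up stand-in for df_irv (an assumption; the real dataset is not shown)
n = 300
df = pd.DataFrame({'Distance': np.arange(n, dtype=float),
                   'SuspTravel': np.sin(np.arange(n) / 10.0),
                   'Element': np.where(np.arange(n) % 10 == 0, 'CURVA', 'RECTA')})

fig, ax = plt.subplots()

# Computed once, outside the loop, instead of once per iteration
y_strip = df['SuspTravel'].max() + 0.5

dist = df['Distance'].to_numpy()
elem = df['Element'].to_numpy()
for i in range(n - 1):
    # One plot call per short segment, as in the question
    ax.plot(dist[i:i + 2], [y_strip, y_strip],
            color='red' if elem[i] == 'CURVA' else 'blue',
            linewidth=10, alpha=0.5)

# Each call above added a separate Line2D artist to the axes
n_artists = len(ax.lines)
print(n_artists)
```

With a real dataset of many thousands of rows, the number of artists grows in the same way, one per row in the selected range.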
To speed up drawing the short lines, you can use the same approach as the multicolored lines code from matplotlib's tutorial: build all the segments as one numpy array and draw them with a single LineCollection instead of looping. Numpy array operations are much faster than a Python loop because they are implemented in optimized C code.
Here is what the code could look like:
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import pandas as pd
import numpy as np
# create some dummy test data
df_irv = pd.DataFrame({'Distance': np.random.randint(10, 100, 1000).cumsum(),
                       'SuspTravel': np.random.randn(1000).cumsum() * 100,
                       'Element': np.random.choice(['CURVA', 'RECTA'], 1000, p=[.1, .9])})
fig, ax = plt.subplots()
# add plot
ax.plot('Distance', 'SuspTravel', data=df_irv)
# add "transition line"
xmin = df_irv['Distance'].min()
xmax = df_irv['Distance'].max()
id_xmin = (df_irv['Distance'] - xmin).abs().idxmin()
id_xmax = (df_irv['Distance'] - xmax).abs().idxmin()
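# n + 1 x positions give n segments; each segment takes the color of its starting row
xvals = df_irv['Distance'][id_xmin:id_xmax + 1].to_numpy()
# constant height of the strip, computed once with the vectorized Series.max()
yvals = np.full(id_xmax - id_xmin, df_irv['SuspTravel'].max() + 0.5)
colors = df_irv['Element'][id_xmin:id_xmax].map({'CURVA': 'red', 'RECTA': 'blue'})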
segments = np.c_[xvals[:-1], yvals, xvals[1:], yvals].reshape(-1, 2, 2)
lines = LineCollection(segments, colors=colors)
lines.set_linewidth(10)
line = ax.add_collection(lines)
plt.show()
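To plug this into the interactive `plot_graph` from the question, the same idea can be wrapped in a function. The following is only a sketch: `transition_collection` is a hypothetical helper name, it assumes the DataFrame is sorted by `Distance`, and it selects rows with `np.searchsorted` (all rows inside `[xmin, xmax]`) rather than the nearest-row `idxmin()` lookup used above.

```python
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection


def transition_collection(df, xmin, xmax, y, linewidth=10, alpha=0.5):
    """Build one LineCollection for the rows with xmin <= Distance <= xmax.

    Hypothetical helper; assumes df is sorted by 'Distance' and that
    'CURVA' in the 'Element' column marks a curve.
    """
    dist = df['Distance'].to_numpy()
    elem = df['Element'].to_numpy()
    # Binary search for the slice of rows inside [xmin, xmax]
    lo = np.searchsorted(dist, xmin, side='left')
    hi = np.searchsorted(dist, xmax, side='right')
    x = dist[lo:hi]
    if len(x) < 2:
        # Fewer than two points: nothing to draw
        return LineCollection(np.empty((0, 2, 2)))
    n_seg = len(x) - 1
    start = np.column_stack([x[:-1], np.full(n_seg, y)])  # (n_seg, 2) start points
    end = np.column_stack([x[1:], np.full(n_seg, y)])     # (n_seg, 2) end points
    segments = np.stack([start, end], axis=1)             # shape (n_seg, 2, 2)
    # Each segment is colored by the element of its starting row
    colors = ['red' if e == 'CURVA' else 'blue' for e in elem[lo:hi - 1]]
    return LineCollection(segments, colors=colors,
                          linewidths=linewidth, alpha=alpha)


# Tiny made-up example
demo = pd.DataFrame({'Distance': [0, 1, 2, 3, 4, 5],
                     'Element': ['RECTA', 'CURVA', 'RECTA', 'RECTA', 'CURVA', 'RECTA']})
lc = transition_collection(demo, 1, 4, y=10.0)
print(len(lc.get_segments()))
```

Inside `plot_graph`, the call to `transition_line(xmin, xmax)` could then be replaced by something like `plt.gca().add_collection(transition_collection(df_irv, xmin, xmax, df_irv['SuspTravel'].max() + 0.5))`. Note that `add_collection` does not necessarily rescale the axes, so the y limits may need adjusting if the strip falls outside the plotted range.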