Is there a way to simplify the following Python code? The only difference between these four subplots is the viewing angles using the view_init() function.
def plot_example(mydata_dataframe):
fig = plt.figure(figsize=[15,15])
#Create subplots
ax1 = fig.add_subplot(2,2,1, projection='3d')
ax1.scatter3D(data.x, data.y, data.z, c=data.z, cmap='Blues')
ax1.view_init(0,90)
ax1.set_xlabel('x', color ='red')
ax1.set_ylabel('y', color ='red')
ax1.set_zlabel('z', color ='red')
ax1.set_xlim(0, 14)
ax1.set_ylim(-6, 6)
ax1.set_zlim(0, 8.5)
ax2 = fig.add_subplot(2,2,2, projection='3d')
ax2.scatter3D(data.x, data.y, data.z, c=data.z, cmap='Blues')
ax2.view_init(45,0)
ax2.set_xlabel('x', color ='red')
ax2.set_ylabel('y', color ='red')
ax2.set_zlabel('z', color ='red')
ax2.set_xlim(0, 14)
ax2.set_ylim(-6, 6)
ax2.set_zlim(0, 8.5)
ax3 = fig.add_subplot(2,2,3, projection='3d')
ax3.scatter3D(data.x, data.y, data.z, c=data.z, cmap='Blues')
ax3.view_init(35,45)
ax3.set_xlabel('x', color ='red')
ax3.set_ylabel('y', color ='red')
ax3.set_zlabel('z', color ='red')
ax3.set_xlim(0, 14)
ax3.set_ylim(-6, 6)
ax3.set_zlim(0, 8.5)
ax4 = fig.add_subplot(2,2,4, projection='3d')
ax4.scatter3D(data.x, data.y, data.z, c=data.z, cmap='Blues')
ax4.view_init(20,40)
ax4.set_xlabel('x', color ='red')
ax4.set_ylabel('y', color ='red')
ax4.set_zlabel('z', color ='red')
ax4.set_xlim(0, 14)
ax4.set_ylim(-6, 6)
ax4.set_zlim(0, 8.5)
I tried this:-
import matplotlib.pyplot as plt
def plot_example(mydata_dataframe):
fig = plt.figure(figsize=[15, 15])
# Create subplots and plot data
for i, ax in enumerate([fig.add_subplot(2, 2, i + 1, projection='3d') for i in range(4)]):
ax.scatter3D(data_df.x, data_df.y, data_df.z, c=data_df.z, cmap='Blues')
ax.view_init(*[30 * i, 30 * (i + 1)]) # Set different viewing angles
ax.set_xlabel('x', color='red')
ax.set_ylabel('y', color='red')
ax.set_zlabel('z', color='red')
# Set axis limits for all subplots
plt.axis([0, 14, -6, 6, 0, 8])
plt.show()
But got the error "TypeError: the first argument to axis() must be an iterable of the form [xmin, xmax, ymin, ymax]"
First of all, the plt.axis()
function only has 4 parameters allowed, instead of the six you have mentioned in your code. Since your z limits are constant, you can add ax.set_zlim(0, 8.5)
to your for loop itself. In fact, you can add all ax.set_xlim
and ax.set_ylim
to the loop as well.
In summary plt.axis() only works for two dimensions.
Your code should look something like this:
import matplotlib.pyplot as plt
def plot3Ddata(data_df):
fig = plt.figure(figsize=[15, 15])
# Create subplots and plot data
for i, ax in enumerate([fig.add_subplot(2, 2, i + 1, projection='3d') for i in range(4)]):
ax.scatter3D(data_df.x, data_df.y, data_df.z, c=data_df.z, cmap='Blues')
ax.view_init(*[30 * i, 30 * (i + 1)]) # Set different viewing angles
ax.set_xlabel('x', color='red')
ax.set_ylabel('y', color='red')
ax.set_zlabel('z', color='red')
ax.set_xlim(0, 14)
ax.set_ylim(-6, 6)
ax.set_zlim(0, 8.5)
plt.show()