Tags: python, matplotlib, mplot3d, contourf, matplotlib-3d

set_xlim and set_ylim not working for contourf in a 3D plot


I want to draw a 2D contour slice inside a 3D plot, where the x and y range of the data is larger than the xlim and ylim I set. However, when I set xlim and ylim, the contour is not clipped and extends outside the axes. Is there a way to confine the contour to the axes limits?

import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import axes3d
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
X, Y, Z = axes3d.get_test_data(0.05)

# Draw the filled contours of Z as a flat slice at z = -100, which is the
# floor of the box given the z-limits set below.
cset = ax.contourf(X, Y, Z, zdir='z', offset=-100, cmap=cm.coolwarm)

# Restrict the view to a window smaller than the data extent.  This is the
# step that exposes the problem: the contour is not clipped to these limits.
ax.set(xlim=(-20, 20), ylim=(-20, 20), zlim=(-100, 100))
ax.set(xlabel='X', ylabel='Y', zlabel='Z')

plt.show()

Figure (image not available): the filled contour extends beyond the axes limits.


Solution

  • You could use numpy.where to blank out the values of Z that lie outside the x and y limits, so that contourf has nothing to draw there:

    # Blank out every sample outside the [-20, 20] x/y window.  NaN is used
    # instead of None so that Z stays a float array rather than becoming an
    # object array.
    outside = (np.abs(X) > 20) | (np.abs(Y) > 20)
    Z = np.where(outside, np.nan, Z)
    

    Example:

    import matplotlib.pyplot as plt
    from matplotlib import cm
    from mpl_toolkits.mplot3d import axes3d
    import numpy as np
    
    ax = plt.figure().add_subplot(projection='3d')
    X, Y, Z = axes3d.get_test_data(0.05)

    # Blank out every sample outside the [-20, 20] x/y window.  NaN is used
    # instead of None so that Z stays a float array rather than becoming an
    # object array; the blanked region is then left unfilled by contourf.
    outside = (np.abs(X) > 20) | (np.abs(Y) > 20)
    Z = np.where(outside, np.nan, Z)

    # Draw the filled contours of the trimmed data as a flat slice at
    # z = -100, which is the floor of the box given the z-limits set below.
    cset = ax.contourf(X, Y, Z, zdir='z', offset=-100, cmap=cm.coolwarm)

    ax.set_xlim(-20, 20)
    ax.set_ylim(-20, 20)
    ax.set_zlim(-100, 100)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    plt.show()
    

    Figure (image not available): the resulting plot.