Tags: python, pandas, matplotlib, geopandas, drawnow

GeoPandas plotting - any way to speed things up?


I'm running a gradient descent algorithm on some geo data. The goal is to assign different areas to different clusters to minimize some objective function. I am trying to make a short movie showing how the algorithm progresses. Right now my approach is to plot the map at each step, then use some other tools to make a little movie from all the static images (pretty simple). But, I have about 3000 areas to plot and the plot command takes a good 90 seconds or more to run, which kills my algorithm.

There are some obvious shortcuts: save images every Nth iteration, save all the steps in a list and make all the images at the end (perhaps in parallel). That's all fine for now, but ultimately I'm aiming for some interactive functionality where a user can put in some parameters and see their map converge in real time. Seems like updating the map on the fly would be best in that case.

Any ideas? Here's the basic command (using the latest dev version of geopandas):

fig, ax = plt.subplots(1,1, figsize=(7,5))
geo_data.plot(column='cluster',ax=ax, cmap='gist_rainbow',linewidth=0)
fig.savefig(filename, bbox_inches='tight', dpi=400)

Also tried something akin to the following (an abbreviated version is below). I open a single plot, and change it and save it with each iteration. Doesn't seem to speed things up at all.

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
plot = geo_data.plot(ax=ax)
for it in range(100):  # just doing 100 iterations for now
    clusters = get_clusters(...)
    # recolor each district patch according to its current cluster
    for i_d, district in enumerate(plot.patches):
        if clusters[i_d] == 1:
            district.set_color("#FF0000")
        else:
            district.set_color("#d3d3d3")
    fig.savefig('test' + str(it) + '.pdf')

Update: I've taken a look at drawnow and other pointers from "real-time plotting in while loop with matplotlib", but shapefiles seem to be too big/clunky to work in real time.


Solution

  • I think two things can improve the performance: 1) using a matplotlib Collection (the current geopandas implementation plots each polygon separately), and 2) updating only the colors of the polygons instead of plotting them again at each iteration (you already do this, but it becomes much simpler with a collection).

    1) Using a matplotlib Collection to plot the Polygons

    This is a possible implementation of a more efficient function for plotting a GeoSeries of Polygons:

    import numpy as np
    from matplotlib.collections import PatchCollection
    # aliased so that a later "from shapely.geometry import Polygon" cannot shadow it
    from matplotlib.patches import Polygon as MplPolygon
    
    def plot_polygon_collection(ax, geoms, values=None, colormap='Set1', facecolor=None, edgecolor=None,
                                alpha=0.5, linewidth=1.0, **kwargs):
        """ Plot a collection of Polygon geometries """
        patches = []
    
        for poly in geoms:
            # exterior ring as an (N, 2) array; [:, :2] drops any z coordinate
            a = np.asarray(poly.exterior.coords)[:, :2]
            patches.append(MplPolygon(a))
    
        patches = PatchCollection(patches, facecolor=facecolor, linewidth=linewidth, edgecolor=edgecolor, alpha=alpha, **kwargs)
    
        if values is not None:
            patches.set_array(values)
            patches.set_cmap(colormap)
    
        ax.add_collection(patches, autolim=True)
        ax.autoscale_view()
        return patches
    

    This is roughly 10x faster than the current geopandas plotting method.

    2) Updating the colors of the Polygons

    Once you have the figure, the colors of a Collection of Polygons can be updated in one step with the set_array method, where the values in the array determine the color of each polygon (each value is converted to a color through the colormap).

    E.g. (assuming s_poly is a GeoSeries of polygons):

    fig, ax = plt.subplots(subplot_kw=dict(aspect='equal'))
    col = plot_polygon_collection(ax, s_poly.geometry)
    # update the color
    col.set_array( ... )
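
As a rough illustration of what "converted to a color depending on the colormap" means, the snippet below reproduces the mapping by hand: the values are normalized to the range 0 to 1 and then passed through the colormap. This is only a sketch of the idea, not the collection's actual internals; the `values` array and the choice of the 'Set1' colormap are taken from the examples in this answer, and the `Normalize` limits are an assumption (by default a collection scales to the min and max of the array).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize

# cluster label per polygon, as passed to set_array
values = np.array([0, 1, 0])

# step 1: scale the values linearly to [0, 1]
norm = Normalize(vmin=values.min(), vmax=values.max())

# step 2: look up each scaled value in the colormap
cmap = plt.get_cmap("Set1")
rgba = cmap(norm(values))  # one RGBA row per polygon

print(rgba)
```

Polygons with the same value end up with the same RGBA row, which is why a single array of cluster labels is enough to recolor the whole map.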
    

    Full example with some dummy data:

    import geopandas
    from shapely.geometry import Polygon
    
    p1 = Polygon([(0, 0), (1, 0), (1, 1)])
    p2 = Polygon([(2, 0), (3, 0), (3, 1), (2, 1)])
    p3 = Polygon([(1, 1), (2, 1), (2, 2), (1, 2)])
    s = geopandas.GeoSeries([p1, p2, p3])
    

    Plotting this:

    fig, ax = plt.subplots(subplot_kw=dict(aspect='equal'))
    col = plot_polygon_collection(ax, s.geometry)
    

    gives:

    [image: plot of the three polygons]

    Then updating the color with an array indicating the clusters:

    col.set_array(np.array([0,1,0]))
    

    gives:

    [image: the same plot with the polygons recolored by cluster]
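
Putting the two points together for the movie use case, a minimal sketch could look like the following. It is an illustration under stated assumptions, not a drop-in solution: `render_frames`, `squares` and the cluster histories are hypothetical names and dummy data, the polygons are plain vertex arrays fed to matplotlib's `PolyCollection` rather than a GeoSeries, and the frames are kept in memory as PNG bytes instead of being written to files.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PolyCollection

# Hypothetical stand-in for the real areas: a 3x3 grid of unit squares,
# each given as an (N, 2) array of vertices.
squares = [
    np.array([(x, y), (x + 1, y), (x + 1, y + 1), (x, y + 1)], dtype=float)
    for x in range(3)
    for y in range(3)
]


def render_frames(polygons, cluster_history, n_clusters, dpi=100):
    """Build the collection once, then only recolor it for each step."""
    fig, ax = plt.subplots(subplot_kw=dict(aspect="equal"))
    col = PolyCollection(polygons, cmap="gist_rainbow", linewidth=0)
    # fix the color scale so a cluster keeps its color across frames
    col.set_clim(0, n_clusters - 1)
    ax.add_collection(col)
    ax.autoscale_view()

    frames = []
    for clusters in cluster_history:
        # the only per-iteration work: swap the value array and re-render
        col.set_array(np.asarray(clusters, dtype=float))
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi)
        frames.append(buf.getvalue())

    plt.close(fig)
    return frames


# two dummy iterations: everything in cluster 0, then three clusters
history = [[0] * 9, [0, 1, 2] * 3]
frames = render_frames(squares, history, n_clusters=3)
print(len(frames), "frames rendered")
```

Fixing the color limits with `set_clim` is a design choice for movies: without it, the collection may rescale to whatever values are present, so the same cluster could change color between frames. Rendering at a lower dpi than the 400 used in the question should also cut the time per frame.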