I am trying to create this:
The data for the chart is:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Number of sites per country for the two years being compared.
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Denmark", "Denmark", "Norway", "Norway", "Sweden", "Sweden"],
    "sites": [4, 10, 5, 8, 13, 15],
}
df = pd.DataFrame(data)

# Row-to-row change in "sites" within each country; the first row of every
# country has no predecessor and therefore comes out as NaN.
df["diff"] = df.groupby(["countries"])["sites"].diff()

# Replace those NaNs with the row's own site count.
# The result is assigned back rather than calling fillna(inplace=True) on
# df["diff"]: an in-place call on a column selection is deprecated in
# pandas 2.x and does not modify `df` once Copy-on-Write is enabled.
df["diff"] = df["diff"].fillna(df["sites"])

df  # notebook-style display of the resulting frame
I am aware that there are packages that do treemaps (squarify and plotly, to name a few), but I have not figured out how to make one like the chart above, where the values for the two years are stacked on each other (or, to be exact, where the difference between them is shown). It would be fantastic to learn how to do it in pure matplotlib, if it is not too complex.
Does anyone have any pointers? I haven't found a lot of info on treemaps on Google.
There are two parts to this task: computing the treemap layout, and drawing the rectangles.
The first part can get quite involved: people publish scientific papers on the topic. It's not advisable to re-invent the wheel here. However, the second part is quite straightforward and can be done in matplotlib.
The solution below uses squarify to compute a layout using the larger value for each value pair, and then matplotlib to draw two rectangles on top of each other.
import numpy as np
import matplotlib.pyplot as plt
import squarify
from matplotlib import colormaps
from matplotlib.colors import to_rgba
# Default color pairs: consecutive entries of matplotlib's "tab20" colormap
# grouped two by two, i.e. (colors[0], colors[1]), (colors[2], colors[3]), ...
DEFAULT_COLORS = [
    (even_entry, odd_entry)
    for even_entry, odd_entry in zip(
        colormaps["tab20"].colors[::2],
        colormaps["tab20"].colors[1::2],
    )
]
def color_to_grayscale(color):
    """Return a single brightness value for *color*.

    The color is converted to RGBA with ``to_rgba``; the result is a weighted
    sum of the red, green and blue channels, multiplied by the alpha channel.
    """
    # Adapted from: https://stackoverflow.com/a/689547/2912349
    red, green, blue, alpha = to_rgba(color)
    weighted_rgb = 0.299 * red + 0.587 * green + 0.114 * blue
    return weighted_rgb * alpha
class PairedTreeMap:
    """Treemap whose tiles each display a pair of values as two stacked rectangles.

    Each tile is sized by the larger value of its pair. The smaller value is
    drawn on top of it as a rectangle sharing the tile's (x, y) corner, scaled
    so that the ratio of the two areas equals the ratio of the two values.
    """

    def __init__(self, values, colors=DEFAULT_COLORS, labels=None, ax=None, bbox=(0, 0, 200, 100)):
        """
        Draw a treemap of value pairs.

        values : list[tuple[float, float]]
            A list of value pairs.
        colors : list[tuple[RGBA, RGBA]]
            The corresponding color pairs; within a pair, the first color
            belongs to the first value and the second color to the second
            value. Defaults to DEFAULT_COLORS. If fewer color pairs than
            value pairs are given, the color pairs are reused cyclically.
        labels : list[str]
            The labels, one for each pair.
        ax : matplotlib.axes._axes.Axes
            The matplotlib axis instance to draw on. A new figure and axis
            are created if None.
        bbox : tuple[float, float, float, float]
            The (x, y) origin and (width, height) extent of the treemap.
        """
        self.ax = self.initialize_axis(ax)
        self.rects = self.get_layout(values, bbox)
        self.artists = list(self.draw(self.rects, values, colors, self.ax))
        # Always define the attribute so it can be read even without labels.
        if labels is None:
            self.labels = []
        else:
            self.labels = list(self.add_labels(self.rects, labels, values, colors, self.ax))

    @staticmethod
    def _order_pair(value_pair, color_pair):
        """Return (small, large, color_small, color_large) for one pair.

        Sorting uses the values only. Sorting the (value, color) tuples
        directly would fall back to comparing the colors whenever the two
        values are equal, which makes the outcome depend on the color
        specification and fails for colors that cannot be compared with
        each other (e.g. a color name and an RGBA tuple). The sort is stable,
        so tied values keep their input order.
        """
        (small, color_small), (large, color_large) = sorted(
            zip(value_pair, color_pair), key=lambda item: item[0]
        )
        return small, large, color_small, color_large

    @staticmethod
    def _match_length(colors, count):
        """Return `count` color pairs, reusing `colors` cyclically if it is shorter.

        Without this, zipping against a shorter color list would silently
        drop the remaining tiles. An empty `colors` is returned unchanged.
        """
        colors = list(colors)
        if not colors:
            return colors
        return [colors[index % len(colors)] for index in range(count)]

    def get_layout(self, values, bbox):
        """Compute one squarify rectangle per value pair, in input order.

        Tiles are sized by the larger value of each pair. The maxima are
        handed to squarify in descending order, and the resulting rectangles
        are then put back into the order of `values`.
        """
        maxima = np.max(values, axis=1)
        order = np.argsort(maxima)[::-1]
        normalized_maxima = squarify.normalize_sizes(maxima[order], *bbox[2:])
        rects = squarify.padded_squarify(normalized_maxima, *bbox)
        # argsort of a permutation is its inverse: undo the descending sort.
        reorder = np.argsort(order)
        return [rects[ii] for ii in reorder]

    def initialize_axis(self, ax=None):
        """Return the axis to draw on, with equal aspect and the frame hidden.

        A new figure and axis are created when `ax` is None.
        """
        if ax is None:
            fig, ax = plt.subplots()
        ax.set_aspect("equal")
        ax.axis("off")
        return ax

    def _get_artist_pair(self, rect, value_pair, color_pair):
        """Build the (large, small) rectangle patches for a single tile.

        The large patch fills `rect`. The small patch shares its (x, y)
        corner and has both sides scaled by sqrt(small / large), so that the
        ratio of the areas equals the ratio of the values.
        """
        x, y, w, h = rect["x"], rect["y"], rect["dx"], rect["dy"]
        small, large, color_small, color_large = self._order_pair(value_pair, color_pair)
        # A pair whose larger value is 0 would otherwise divide by zero.
        ratio = np.sqrt(small / large) if large else 0.0
        return (plt.Rectangle((x, y), w, h, color=color_large, zorder=1),
                plt.Rectangle((x, y), w * ratio, h * ratio, color=color_small, zorder=2))

    def draw(self, rects, values, colors, ax):
        """Add the patches of every tile to `ax`.

        This is a generator yielding one (large_patch, small_patch) tuple per
        tile; the axis view is autoscaled once the generator is exhausted.
        """
        colors = self._match_length(colors, len(rects))
        for rect, value_pair, color_pair in zip(rects, values, colors):
            large_patch, small_patch = self._get_artist_pair(rect, value_pair, color_pair)
            ax.add_patch(large_patch)
            ax.add_patch(small_patch)
            yield (large_patch, small_patch)
        ax.autoscale_view()

    def add_labels(self, rects, labels, values, colors, ax):
        """Place each label at the center of its tile.

        This is a generator yielding one matplotlib text artist per label.
        The font color is white on dark backgrounds and black on light ones.
        """
        colors = self._match_length(colors, len(rects))
        for rect, label, value_pair, color_pair in zip(rects, labels, values, colors):
            x, y, w, h = rect["x"], rect["y"], rect["dx"], rect["dy"]
            # decide a fontcolor based on background brightness
            small, large, color_small, color_large = self._order_pair(value_pair, color_pair)
            ratio = small / large if large else 0.0
            # The small rectangle has sides scaled by sqrt(ratio), so it only
            # reaches the tile center (where the label sits) once the ratio
            # exceeds 0.25; the threshold is 0.25 + some fudge.
            if ratio < 0.33:
                background_brightness = color_to_grayscale(color_large)
            else:
                background_brightness = color_to_grayscale(color_small)
            fontcolor = "white" if background_brightness < 0.5 else "black"
            yield ax.text(x + w/2, y + h/2, label, va="center", ha="center", color=fontcolor)
if __name__ == "__main__":
    # Demo input, one row per tile: (label, value pair, color pair).
    demo_rows = [
        ("Denmark", (4, 10), ("red", "coral")),
        ("Sweden", (13, 15), ("royalblue", "cornflowerblue")),
        ("Norway", (5, 8), ("darkslategrey", "gray")),
    ]
    # Split the rows into the three parallel lists the treemap expects.
    labels = [row[0] for row in demo_rows]
    values = [row[1] for row in demo_rows]
    colors = [row[2] for row in demo_rows]

    PairedTreeMap(values, colors=colors, labels=labels, bbox=(0, 0, 100, 100))
    plt.show()