How can I replace the node plots in dtreeviz with my own custom plot function?
Alternatively: I want to replace the dtreeviz plots with a 2D histogram, where the y-axis shows the target values, the x-axis shows the values of the split feature, and a grid is laid over the plot so that each grid cell is colored by the number of samples it contains. (It would also be great if some package already implements this.) In matplotlib, the plotting function for this is hist2d().
I use sklearn to fit a regression decision tree and visualize the result with dtreeviz.
MWE (see https://github.com/parrt/dtreeviz#regression-decision-tree):
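# Note: load_boston was removed in scikit-learn 1.2, so this needs an older version.
from sklearn.datasets import load_boston
from sklearn import tree
from dtreeviz.trees import dtreeviz

# Fit a shallow regression tree on the Boston housing data
regr = tree.DecisionTreeRegressor(max_depth=2)
boston = load_boston()
regr.fit(boston.data, boston.target)

# Render the tree; each node gets its own scatter plot
viz = dtreeviz(regr,
               boston.data,
               boston.target,
               target_name='price',
               feature_names=boston.feature_names)
viz.view()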
My actual problem, however, has millions of samples, and the resulting .svg is extremely slow (read: impossible) to display. The only way I can use this visualization is by downsampling.
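For reference, the downsampling workaround can be sketched roughly as follows. The helper `downsample`, its `n_max` and `seed` parameters, and the synthetic data are hypothetical names introduced for illustration; none of them are part of dtreeviz or sklearn. The idea is to fit the tree on the full data but hand only a random subset to the visualization.

```python
import numpy as np

def downsample(X, y, n_max=10_000, seed=0):
    """Return a random subset of at most n_max rows (hypothetical helper)."""
    X = np.asarray(X)
    y = np.asarray(y)
    if len(X) <= n_max:
        return X, y
    rng = np.random.default_rng(seed)
    # Sample row indices without replacement so no sample is drawn twice
    idx = rng.choice(len(X), size=n_max, replace=False)
    return X[idx], y[idx]

# Synthetic stand-in for a large dataset (assumption, not the real data)
rng = np.random.default_rng(42)
X_big = rng.normal(size=(200_000, 5))
y_big = 2.0 * X_big[:, 0] + rng.normal(size=200_000)

X_small, y_small = downsample(X_big, y_big, n_max=5_000)
# Fit on X_big, y_big, then pass only X_small, y_small to the plotting call.
```

This keeps the .svg small, but the scatter plots then show only a sample of each node's data, which is why a binned plot would be preferable.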
Example 2d histogram:
(From https://matplotlib.org/gallery/scales/power_norm.html#sphx-glr-gallery-scales-power-norm-py)
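To make the request concrete, here is a minimal sketch of the per-node plot I have in mind, drawn directly with matplotlib and sklearn rather than through dtreeviz. The data is synthetic, and the bin count, figure layout, and output file name are arbitrary assumptions. It uses sklearn's `decision_path` to find the samples that reach each internal node and then calls `hist2d` on the split feature versus the target.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic data standing in for the real problem (assumption)
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(50_000, 3))
y = np.where(X[:, 0] > 5, 3.0, 1.0) + 0.5 * X[:, 1] + rng.normal(size=50_000)

regr = DecisionTreeRegressor(max_depth=2).fit(X, y)
t = regr.tree_

# Sparse (n_samples, n_nodes) indicator: which samples pass through which node
node_indicator = regr.decision_path(X).tocsc()

# Internal nodes are those with a left child (leaves have children_left == -1)
internal = [i for i in range(t.node_count) if t.children_left[i] != -1]

fig, axes = plt.subplots(1, len(internal),
                         figsize=(4 * len(internal), 3.5), squeeze=False)
for ax, node in zip(axes[0], internal):
    rows = node_indicator[:, node].nonzero()[0]  # samples reaching this node
    feat = t.feature[node]                       # index of the split feature
    h = ax.hist2d(X[rows, feat], y[rows], bins=40)
    ax.axvline(t.threshold[node], color="red")   # mark the split threshold
    ax.set_xlabel(f"feature {feat}")
    ax.set_ylabel("y")
    ax.set_title(f"node {node} (n={len(rows)})")
    fig.colorbar(h[3], ax=ax, label="samples")
fig.tight_layout()
fig.savefig("node_hist2d.png", dpi=100)
```

Because every panel is a fixed-size grid of bins, the rendering cost does not grow with the number of samples. What this sketch does not do is arrange the panels in a tree layout with edges, which is the part dtreeviz provides.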
Sorry, but you would have to alter the software as it was not designed to have plug-and-play node figures. It was extremely difficult to convince all of the tools in the chain to work together, even without allowing such flexibility.