I'm using the anytree package in Python to create and display a k-dimensional tree; however, the tree is not rendered properly.
This is part of my source code for my K-dimensional node:
import numpy as np
from anytree import NodeMixin
class KdData:
    """Data part of a k-d tree node; holds a single ``key`` attribute."""

    def __init__(self) -> None:
        # A freshly created object has no key yet.
        self.key = None
class KdNode(KdData, NodeMixin):
    """k-d tree node that is also an anytree node (via ``NodeMixin``).

    A node whose ``key`` is ``None`` is an empty placeholder. Giving a node a
    key with ``set_key`` assigns it two empty children: index 0 is used as
    the left child and index 1 as the right child.
    """

    def __init__(self, parent=None, children=None, axis: int = 0) -> None:
        """Create an empty node.

        Args:
            parent: Node to attach this node to, or ``None``.
            children: Optional children to assign; skipped when falsy.
            axis: Index of the key component this node compares on.
        """
        super().__init__()
        self.axis = axis
        self.parent = parent
        if children:
            self.children = children

    def set_key(self, key):
        """Store ``key`` as a NumPy array and assign two empty children.

        Both children compare on the next axis, wrapping around after the
        last component of the key.
        """
        self.key = np.array(key)
        # np.array() never returns None, so the children are always created.
        next_axis = (self.axis + 1) % self.key.shape[0]
        self.children = [KdNode(axis=next_axis), KdNode(axis=next_axis)]

    def _child(self, index):
        """Return the child at ``index``, or ``None`` if this node has no key."""
        if self.key is None:
            return None
        return self.children[index]

    def get_left_node(self):
        """Return the left child (index 0), or ``None`` for an empty node."""
        return self._child(0)

    def get_right_node(self):
        """Return the right child (index 1), or ``None`` for an empty node."""
        return self._child(1)

    def insert(self, key):
        """Insert ``key`` into the subtree rooted at this node.

        An empty node takes the key itself. Otherwise the key goes left when
        its component on this node's axis is strictly smaller than the
        node's own, and right in every other case (including ties).
        """
        if self.key is None:
            self.set_key(key)
            return
        if key[self.axis] < self.key[self.axis]:
            branch = self.get_left_node()
        else:
            branch = self.get_right_node()
        branch.insert(key)
# Other functions are omitted because they are not relevant to the question
Then this is the tree I created.
# Build the sample tree: the root gets the key [5, 6], then the remaining
# points are inserted one at a time, in this order.
tree = KdNode()
tree.set_key(np.array([5, 6]))
for point in ([4, 7], [17, 16], [7, 8], [1, 4], [9, 13]):
    tree.insert(point)
And this is my exporter:
def node_attribute(node):
    """Return the attribute string used for ``node`` in the dot output.

    Every node gets the same attributes, so ``node`` itself is not inspected.
    """
    return "shape=plaintext"
def edge_att(source, target):
    """Return the attribute string for the edge from ``source`` to ``target``.

    An edge whose target has no key (``target.key is None``) is marked
    invisible; every other edge gets an empty attribute string.
    ``source`` is not inspected.
    """
    return "style=invis" if target.key is None else ""
from anytree.exporter.dotexporter import DotExporter, UniqueDotExporter
# Set up the dot export of the tree.
# NOTE(review): every key-less node is given the same name "" here. If the
# exporter uses this name as the Graphviz node identifier, all of those nodes
# would collapse into a single one - confirm against the anytree docs.
dot_obj = UniqueDotExporter(
    tree,
    nodenamefunc=lambda node: "" if node.key is None else node.key,
    nodeattrfunc=node_attribute,
    edgeattrfunc=edge_att,
)
The result I obtained is here. The expected result is that the [4 7] node and the [17 16] node are on the same rank.
Why was my K-dimensional tree not rendered properly?
If you remove att += "style=invis" and use None instead of "" in
nodenamefunc=lambda node: node.key if node.key is not None else None,
then you get:
And if you add __repr__
in KdNode
def __repr__(self):
    """Represent the node by the string form of its key."""
    return str(self.key)
and you run print(RenderTree(tree))
then you see:
[5 6]
├── [4 7]
│ ├── [1 4]
│ │ ├── None
│ │ └── None
│ └── None
└── [17 16]
├── [7 8]
│ ├── None
│ └── [ 9 13]
│ ├── None
│ └── None
└── None
These None nodes are probably what causes the whole problem.
If you use filter_=lambda node: node.key is not None
# Same export as before, plus a filter that only accepts nodes with a key.
dot_obj = UniqueDotExporter(
    tree,
    nodenamefunc=lambda node: "" if node.key is None else node.key,
    nodeattrfunc=node_attribute,
    edgeattrfunc=edge_att,
    filter_=lambda node: node.key is not None,
)
then you get:
So maybe you should remove the None nodes from your tree, but I don't know how to do that.
Here is the full code I used for testing, so everyone can simply copy and run it.
import numpy as np
from anytree import NodeMixin, AnyNode, RenderTree
from anytree.exporter.dotexporter import UniqueDotExporter
from anytree.dotexport import RenderTreeGraph
class KdData:
    """Data part of a k-d tree node; holds a single ``key`` attribute."""

    def __init__(self) -> None:
        # A freshly created object has no key yet.
        self.key = None
class KdNode(KdData, NodeMixin):
    """k-d tree node that is also an anytree node (via ``NodeMixin``).

    A node whose ``key`` is ``None`` is an empty placeholder. Giving a node a
    key with ``set_key`` assigns it two empty children: index 0 is used as
    the left child and index 1 as the right child.
    """

    def __init__(self, parent=None, children=None, axis: int = 0) -> None:
        """Create an empty node.

        Args:
            parent: Node to attach this node to, or ``None``.
            children: Optional children to assign; skipped when falsy.
            axis: Index of the key component this node compares on.
        """
        super().__init__()
        self.axis = axis
        self.parent = parent
        if children:
            self.children = children

    def set_key(self, key):
        """Store ``key`` as a NumPy array and assign two empty children.

        Both children compare on the next axis, wrapping around after the
        last component of the key.
        """
        self.key = np.array(key)
        # np.array() never returns None, so the children are always created.
        next_axis = (self.axis + 1) % self.key.shape[0]
        self.children = [KdNode(axis=next_axis), KdNode(axis=next_axis)]

    def _child(self, index):
        """Return the child at ``index``, or ``None`` if this node has no key."""
        if self.key is None:
            return None
        return self.children[index]

    def get_left_node(self):
        """Return the left child (index 0), or ``None`` for an empty node."""
        return self._child(0)

    def get_right_node(self):
        """Return the right child (index 1), or ``None`` for an empty node."""
        return self._child(1)

    def insert(self, key):
        """Insert ``key`` into the subtree rooted at this node.

        An empty node takes the key itself. Otherwise the key goes left when
        its component on this node's axis is strictly smaller than the
        node's own, and right in every other case (including ties).
        """
        if self.key is None:
            self.set_key(key)
            return
        if key[self.axis] < self.key[self.axis]:
            branch = self.get_left_node()
        else:
            branch = self.get_right_node()
        branch.insert(key)

    # Other functions are omitted because they are not relevant to the question

    def __repr__(self):
        """Represent the node by the string form of its key."""
        return str(self.key)
# --- main ---
# Build the sample tree: the root gets the key [5, 6], then the remaining
# points are inserted one at a time, in this order.
tree = KdNode()
tree.set_key(np.array([5, 6]))
for point in ([4, 7], [17, 16], [7, 8], [1, 4], [9, 13]):
    tree.insert(point)
def node_attribute(node):
    """Return the attribute string used for ``node`` in the dot output.

    Every node gets the same attributes, so ``node`` itself is not inspected.
    """
    return "shape=plaintext"
def edge_att(source, target):
    """Return the attribute string for the edge from ``source`` to ``target``.

    An edge whose target has no key (``target.key is None``) is marked
    invisible; every other edge gets an empty attribute string.
    ``source`` is not inspected.
    """
    return "style=invis" if target.key is None else ""
# Set up the dot export of the tree; the filter only accepts nodes with a key.
# Alternative to try for the node names (None instead of ""):
#   nodenamefunc=lambda node: node.key if node.key is not None else None,
dot_obj = UniqueDotExporter(
    tree,
    nodenamefunc=lambda node: "" if node.key is None else node.key,
    nodeattrfunc=node_attribute,
    edgeattrfunc=edge_att,
    filter_=lambda node: node.key is not None,
)

# Write the picture to disk, then print the text rendering of the full tree.
dot_obj.to_picture('image.png')
print(RenderTree(tree))