I am working on an application where my database objects often have multiple parents and multiple children, and would like to create a SQLAlchemy query that will return all descendants of an object.
Realizing that I am basically trying to store a graph in a SQL database, I found that setting up a self-referential many-to-many schema got me most of the way there, but I am having trouble writing the query that returns all descendants of a node. I tried to follow SQLAlchemy's recursive CTE example, which looks like the right approach, but I have been running into problems getting it to work. I think my situation differs from the example because, in my case, accessing `Node.child` (and `Node.parent`) returns an instrumented list rather than a single ORM object.
In any case, the code below will set up a simple directed acyclic disconnected graph that looks like this (where the direction is inferred to be from the higher row to the lower one):
```text
a   b   c
 \ / \  |
  d   e f
  |\ /
  g h
  |
  i
```
And what I'm looking for is some help writing a query that will give me all descendants of a node.
- `get_descendants(d)` should return g, h, i
- `get_descendants(b)` should return d, e, g, h, i
Example code:
from sqlalchemy.orm import aliased
from sqlalchemy import Column, ForeignKey, Integer, Table, Text
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker
# Declarative base; its metadata collects both tables defined below.
Base = declarative_base()

# In-memory SQLite database; echo=True makes the engine log the SQL it emits.
engine = create_engine('sqlite:///:memory:', echo=True)
Session = sessionmaker(bind=engine)
session = Session()

# Edge list for the self-referential many-to-many relationship: each row is
# one directed edge parent_id -> child_id.  Both columns form the composite
# primary key, so the same edge cannot be stored twice.
association_table = Table(
    'association_table',
    Base.metadata,
    *(
        Column(column_name, Integer, ForeignKey('node.id'), primary_key=True)
        for column_name in ('parent_id', 'child_id')
    ),
)
class Node(Base):
    """A vertex of a directed graph stored in the ``node`` table.

    Edges are rows of ``association_table``.  ``child`` is the collection of
    nodes this node points to (edge rows whose parent_id is this node's id).
    The ``parent`` backref is the collection of nodes pointing at this one.
    Both are collections, so a node may have any number of parents and
    children.
    """

    __tablename__ = 'node'

    id = Column(Integer, primary_key=True)
    property_1 = Column(Text)
    property_2 = Column(Integer)

    # http://docs.sqlalchemy.org/en/latest/orm/join_conditions.html#self-referential-many-to-many-relationship
    # primaryjoin matches this node to the edge rows where it is the parent;
    # secondaryjoin matches those edge rows to the child nodes.  Both are
    # required because the two foreign keys point at the same table.
    child = relationship(
        'Node',
        secondary=association_table,
        primaryjoin=id == association_table.c.parent_id,
        secondaryjoin=id == association_table.c.child_id,
        backref='parent',
    )
Base.metadata.create_all(engine)

# Nodes a..i; property_2 is each letter's 1-based position.
a, b, c, d, e, f, g, h, i = [
    Node(property_1=letter, property_2=position)
    for position, letter in enumerate('abcdefghi', start=1)
]
session.add_all([a, b, c, d, e, f, g, h, i])

# Directed (parent, child) edges of the example graph.
for parent_node, child_node in (
    (a, d), (b, d), (d, g), (d, h), (g, i), (b, e), (e, h), (c, f),
):
    parent_node.child.append(child_node)

session.commit()
session.close()
The following, surprisingly simple, self-referential many-to-many recursive CTE query returns the desired results for finding all descendants of `b`:
# The question's setup code ends with session.commit() followed by
# session.close(): the commit expires b's attributes and close() detaches b,
# so reading b.id below would raise DetachedInstanceError.  Re-attaching b
# is a no-op when it already belongs to this session.
session.add(b)

nodealias = aliased(Node)

# Anchor member: the starting node b itself (so b is part of the result).
descendants = (
    session.query(Node)
    .filter(Node.id == b.id)
    .cte(name="descendants", recursive=True)
)

# Recursive member: every node whose `parent` collection contains a row
# already in the CTE, i.e. the children of the nodes found so far.  UNION
# (rather than UNION ALL) drops duplicate rows, so a node reachable along
# several paths (h via d and via e) is listed only once.
descendants = descendants.union(
    session.query(nodealias).join(descendants, nodealias.parent)
)
Testing with
# Each CTE row exposes the mapped node columns, including property_1/_2.
for item in session.query(descendants):
    print(item.property_1, item.property_2)
Yields:
b 2
d 4
e 5
g 7
h 8
i 9
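This is the correct list: `b` itself followed by all of its descendants. The anchor (non-recursive) part of the CTE selects `b` itself, so `b` appears in the result. If you want only the descendants, filter the final query with `descendants.c.id != b.id`, or anchor the CTE on `b`'s direct children as the next example does.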
This example adds a convenience method to the `Node` class that returns all descendants of an object, together with the path from the object to each descendant:
from sqlalchemy.orm import aliased
from sqlalchemy import Column, ForeignKey, Integer, Table, Text
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker
Base = declarative_base()

engine = create_engine('sqlite://', echo=True)  # in-memory SQLite, SQL logged
Session = sessionmaker(bind=engine)
session = Session()

# Self-referential many-to-many edge table: one row per directed edge
# parent_id -> child_id, with (parent_id, child_id) as the primary key.
association_table = Table(
    'association_table',
    Base.metadata,
    *(
        Column(edge_end, Integer, ForeignKey('node.id'), primary_key=True)
        for edge_end in ('parent_id', 'child_id')
    ),
)
class Node(Base):
    """A vertex of a directed graph stored in the ``node`` table.

    Edges are rows of ``association_table``.  ``child`` is the collection of
    nodes this node points to, and the ``parent`` backref is the collection
    of nodes pointing at this one.
    """

    __tablename__ = 'node'

    id = Column(Integer, primary_key=True)
    property_1 = Column(Text)
    property_2 = Column(Integer)

    # http://docs.sqlalchemy.org/en/latest/orm/join_conditions.html#self-referential-many-to-many-relationship
    child = relationship(
        'Node',
        secondary=association_table,
        primaryjoin=id == association_table.c.parent_id,
        secondaryjoin=id == association_table.c.child_id,
        backref='parent',
    )

    def descendant_nodes(self):
        """Return ``(property_1, path)`` rows for every node reachable from this one.

        The recursive CTE is anchored on this node's direct children (the
        nodes whose ``parent`` collection contains ``self``) and then follows
        child edges downwards.  ``path`` is the slash-joined chain of
        ``property_1`` values from ``self`` to the descendant, e.g. ``'b/d/g'``.
        ``self`` itself is not included.

        Because each row carries its path, UNION only removes identical rows:
        a node reachable along several paths appears once per path.  On a
        graph containing a cycle the recursion would keep producing longer
        paths, so this is meant for acyclic graphs.  ``self.property_1`` must
        not be None, because it is concatenated in Python to start the path.
        """
        # Local import keeps the method self-contained.
        from sqlalchemy.orm import object_session

        # Query through the Session that owns this instance, so the query runs
        # in that session's transaction.  Fall back to the module-level
        # `session` when the instance is detached.
        db = object_session(self)
        if db is None:
            db = session

        nodealias = aliased(Node)
        # Anchor member: direct children of self, with path "<self>/<child>".
        descendants = (
            db.query(
                Node.id,
                Node.property_1,
                (self.property_1 + '/' + Node.property_1).label('path'),
            )
            .filter(Node.parent.contains(self))
            .cte(recursive=True)
        )
        # Recursive member: children of the rows found so far, path extended.
        descendants = descendants.union(
            db.query(
                nodealias.id,
                nodealias.property_1,
                (descendants.c.path + '/' + nodealias.property_1).label('path'),
            ).join(descendants, nodealias.parent)
        )
        return db.query(descendants.c.property_1, descendants.c.path).all()
Base.metadata.create_all(engine)

# Nodes a..i; property_2 is each letter's 1-based position.
a, b, c, d, e, f, g, h, i = [
    Node(property_1=letter, property_2=position)
    for position, letter in enumerate('abcdefghi', start=1)
]
session.add_all([a, b, c, d, e, f, g, h, i])

# Directed (parent, child) edges; the e -> i edge gives i two parents (g, e).
for parent_node, child_node in (
    (a, d), (b, d), (d, g), (d, h), (g, i), (b, e), (e, h), (c, f), (e, i),
):
    parent_node.child.append(child_node)
session.commit()

# Close the session even if the descendant query raises.
try:
    for item in b.descendant_nodes():
        print(item)
finally:
    session.close()

# Raw string: the drawing contains backslashes that would otherwise be
# invalid escape sequences.
r"""
Graph should be set up like this:

a   b   c
 \ / \  |
  d   e f
  |\ /|
  g h |
  +---+
    i
"""
Output:
('d', 'b/d')
('e', 'b/e')
('g', 'b/d/g')
('h', 'b/d/h')
('h', 'b/e/h')
('i', 'b/e/i')
('i', 'b/d/g/i')