I’m using LangGraph to orchestrate workflows with a parent graph and a subgraph. The subgraph has its own checkpointing mechanism for state persistence, but I need it to retain memory of its own state even when the parent graph modifies shared fields.
Current Behavior:
If the parent graph sets a shared field (e.g., A = None), the subgraph’s checkpoint loses its previous state.
I want the subgraph to ignore parent-overrides for specific fields and persist its own state across invocations.
Desired Behavior:
Subgraph maintains its own state (e.g., A = 5) even if the parent sets A = None.
Checkpoints should store subgraph-specific state independently.
Code Example:
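from typing import Any, TypedDict
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
class State(TypedDict, total=False):
    A: Any
    B: Any
# Subgraph (compiled without its own checkpointer, so it is checkpointed through the parent's)
sub_builder = StateGraph(State)
sub_builder.add_node("sub_node", lambda state: {"B": state.get("A", "default")})
sub_builder.add_edge(START, "sub_node")
sub_builder.add_edge("sub_node", END)
subgraph = sub_builder.compile()
# Parent Graph
parent_builder = StateGraph(State)
parent_builder.add_node("parent_node", lambda state: {"A": None})  # Overrides A
parent_builder.add_node("subgraph", subgraph)  # Connect them: the compiled subgraph runs as a node
parent_builder.add_edge(START, "parent_node")
parent_builder.add_edge("parent_node", "subgraph")
parent_builder.add_edge("subgraph", END)
parent = parent_builder.compile(checkpointer=MemorySaver())
# Initial state (subgraph should remember A=5)
initial_state = {"A": 5, "B": "unset"}
# Run
final_state = parent.invoke(initial_state, config={"configurable": {"thread_id": "1"}})
print(final_state)  # Expect: {"A": None, "B": 5}, but the subgraph sees A=None instead of 5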
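A stdlib-only toy model of why this happens, assuming that a shared key behaves like a single slot where the last write wins (LangGraph's default for a state key that has no reducer). `apply_update`, `parent_node` and `sub_node` here are plain hypothetical functions, not LangGraph code. Because the parent and the subgraph read and write the same slot, the subgraph can only ever see the parent's latest A.

```python
# Toy model (not LangGraph code): every shared key is one slot that keeps
# the last value written to it, like a LangGraph state key without a reducer.

def apply_update(state: dict, update: dict) -> dict:
    """Merge a node's partial update into the shared state (last write wins)."""
    new_state = dict(state)
    new_state.update(update)
    return new_state

def parent_node(state: dict) -> dict:
    return {"A": None}  # overwrites the shared key A

def sub_node(state: dict) -> dict:
    return {"B": state.get("A", "default")}  # reads whatever A holds now

state = {"A": 5, "B": "unset"}
state = apply_update(state, parent_node(state))
state = apply_update(state, sub_node(state))
print(state)  # {'A': None, 'B': None}
```

In this model B ends up None rather than "default", because the key A is present with the value None; either way, the earlier A=5 is no longer reachable once the parent writes A.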
What I’ve Tried:
Explicit checkpoint: subgraph.invoke(state, checkpoint={"A": state["A"]})  # Force subgraph memory (fails: the parent’s A=None still overrides it)
State separation: Used distinct field names (e.g., subgraph_A vs. parent_A), but this feels hacky.
Global storage: Used Redis or a database to store the subgraph state, but this adds external dependencies.
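Both workarounds above amount to giving the subgraph values of its own that the parent cannot overwrite. A minimal in-memory sketch of that idea, assuming a hypothetical dict `store` keyed by (thread_id, namespace) and a hypothetical policy of ignoring a parent's A when it is None; `run_parent` and `run_subgraph` are illustrative names, not LangGraph APIs.

```python
from collections import defaultdict

# Hypothetical in-memory store: one dict per (thread_id, namespace), so the
# subgraph's remembered values live apart from the parent's values.
store: dict = defaultdict(dict)

def run_subgraph(thread_id: str, parent_state: dict) -> dict:
    sub_state = store[(thread_id, "subgraph")]
    # Hypothetical policy: only take A from the parent when it is not None;
    # otherwise keep the value the subgraph remembered for this thread.
    if parent_state.get("A") is not None:
        sub_state["A"] = parent_state["A"]
    return {"B": sub_state.get("A", "default")}

def run_parent(thread_id: str, update: dict) -> dict:
    parent_state = store[(thread_id, "parent")]
    parent_state.update(update)  # the parent may overwrite its own A freely
    return run_subgraph(thread_id, parent_state)

print(run_parent("1", {"A": 5}))     # {'B': 5}
print(run_parent("1", {"A": None}))  # {'B': 5}
```

Keying by thread keeps the memory per conversation: a new thread_id starts with an empty subgraph store and falls back to "default".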
Question:
How can I configure LangGraph to let subgraphs preserve their own state (e.g., A=5) even when the parent graph modifies shared fields?
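Look at this: use a pydantic.BaseModel to define the State, and set __pydantic_fields_set__. When the input is a Pydantic model, LangGraph only writes a field that is None (with a None default) if that field was explicitly set; otherwise the field is skipped and keeps its checkpointed value. Setting __pydantic_fields_set__ controls which fields count as explicitly set.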
from langgraph.constants import START, END
from langgraph.graph.state import StateGraph
from langgraph.checkpoint.memory import MemorySaver
from pydantic import BaseModel
from typing import Optional
class MyState(BaseModel):
    name: Optional[str] = None
    age: Optional[int] = None
def node1(state: MyState) -> MyState:
    print(state)
    state.age = 10
    state.name = "node1"
    return state
g = StateGraph(MyState)
g.add_node("node1", node1)
g.add_edge(START, "node1")
g.add_edge("node1", END)
g = g.compile(checkpointer=MemorySaver())
g.invoke(MyState(name="Nick"), config={"configurable": {"thread_id": "1"}})
# node1 prints: name='Nick' age=None
m = MyState()
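# !!! Important !!! Just add this line: mark every field as explicitly set, so the
# None values are written and overwrite the checkpointed name='node1', age=10.
# Without it, the unset None fields are skipped and node1 sees the checkpointed values.
m.__pydantic_fields_set__ = set(MyState.model_fields.keys())
g.invoke(m, config={"configurable": {"thread_id": "1"}})
# node1 prints: name=None age=None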
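A Pydantic-free toy model of the rule this relies on, assuming LangGraph filters a model input roughly like this (a simplified reading of its behavior, not its actual code): a field whose value is None and whose default is also None is only written when it is in the set of explicitly set fields. `ToyState` and `updates_from_input` are hypothetical names; `fields_set` plays the role of `__pydantic_fields_set__`.

```python
from dataclasses import dataclass, field, fields
from typing import Optional

@dataclass
class ToyState:
    """Stand-in for a Pydantic model: fields_set records explicitly set fields."""
    name: Optional[str] = None
    age: Optional[int] = None
    fields_set: set = field(default_factory=set)

def updates_from_input(inp: ToyState) -> dict:
    # Simplified filter: drop a None value whose default is also None,
    # unless the field was explicitly set.
    out = {}
    for f in fields(inp):
        if f.name == "fields_set":
            continue
        value = getattr(inp, f.name)
        if value is not None or f.default is not None or f.name in inp.fields_set:
            out[f.name] = value
    return out

checkpoint = {"name": "node1", "age": 10}  # what node1 left behind

checkpoint.update(updates_from_input(ToyState()))
print(checkpoint)  # {'name': 'node1', 'age': 10} (nothing overwritten)

checkpoint.update(updates_from_input(ToyState(fields_set={"name", "age"})))
print(checkpoint)  # {'name': None, 'age': None} (explicit Nones written)
```

With real Pydantic v2 models, passing the fields explicitly, e.g. MyState(name=None, age=None), also puts them in model_fields_set, which avoids assigning the dunder attribute by hand.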