
How to implement subgraph memory/persistence in LangGraph when parent and subgraph states diverge?


I’m using LangGraph to orchestrate workflows with a parent graph and a subgraph. The subgraph has its own checkpointing mechanism for state persistence, but I need it to retain memory of its own state even when the parent graph modifies shared fields.

Current Behavior:

If the parent graph sets a shared field (e.g., A = None), the subgraph’s checkpoint loses its previous state.

I want the subgraph to ignore parent-overrides for specific fields and persist its own state across invocations.

Desired Behavior:

Subgraph maintains its own state (e.g., A = 5) even if the parent sets A = None.

Checkpoints should store subgraph-specific state independently.

Code Example:

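from typing import Any, Optional, TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph


class State(TypedDict, total=False):
    A: Optional[int]
    B: Any


# Subgraph with its own checkpoints (checkpointer=True needs a recent LangGraph version)
sub_builder = StateGraph(State)
sub_builder.add_node("sub_node", lambda state: {"B": state.get("A", "default")})
sub_builder.add_edge(START, "sub_node")
sub_builder.add_edge("sub_node", END)
subgraph = sub_builder.compile(checkpointer=True)

# Parent graph: overrides A, then runs the compiled subgraph as a node
parent_builder = StateGraph(State)
parent_builder.add_node("parent_node", lambda state: {"A": None})  # Overrides A
parent_builder.add_node("subgraph", subgraph)
parent_builder.add_edge(START, "parent_node")
parent_builder.add_edge("parent_node", "subgraph")
parent_builder.add_edge("subgraph", END)
parent = parent_builder.compile(checkpointer=MemorySaver())

# Initial state (subgraph should remember A=5)
initial_state = {"A": 5, "B": "unset"}

# Run (a thread_id is required once a checkpointer is attached)
config = {"configurable": {"thread_id": "1"}}
final_state = parent.invoke(initial_state, config=config)
print(final_state)  # Expect: {"A": None, "B": 5} (but B never gets 5)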

What I’ve Tried:

  1. Explicit checkpointing:
subgraph.invoke(state, checkpoint={"A": state["A"]})  # Force subgraph memory (fails: the parent's A=None still overrides it)

  2. State separation: I used distinct field names (e.g., subgraph_A vs. parent_A), but this feels hacky.

  3. Global storage: I stored the subgraph state in Redis or a database, but that adds external dependencies.

Question:

How can I configure LangGraph to let subgraphs preserve their own state (e.g., A=5) even when the parent graph modifies shared fields?


Solution

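  • Define the state as a pydantic.BaseModel and set __pydantic_fields_set__ on the input model. Pydantic records which fields were explicitly set, and a field marked as set overwrites the checkpointed value, even when its value is None, as the second invoke below shows: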

    from typing import Optional

    from langgraph.checkpoint.memory import MemorySaver
    from langgraph.constants import START, END
    from langgraph.graph.state import StateGraph
    from pydantic import BaseModel


    class MyState(BaseModel):
        name: Optional[str] = None
        age: Optional[int] = None


    def node1(state: MyState) -> MyState:
        print(state)  # show the state the node receives
        state.age = 10
        state.name = "node1"
        return state


    builder = StateGraph(MyState)
    builder.add_node("node1", node1)
    builder.add_edge(START, "node1")
    builder.add_edge("node1", END)
    g = builder.compile(checkpointer=MemorySaver())

    config = {"configurable": {"thread_id": "1"}}

    g.invoke(MyState(name="Nick"), config=config)
    # node1 prints: name='Nick' age=None

    m = MyState()

    # !!! Important !!! Just add this line of code: it marks every field as
    # explicitly set, so the None values in `m` replace the checkpointed
    # name='node1', age=10 from the first run.
    m.__pydantic_fields_set__ = set(MyState.model_fields.keys())

    g.invoke(m, config=config)
    # node1 prints: name=None age=None
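The trick relies on Pydantic's record of which fields were explicitly set (`model_fields_set`, backed by `__pydantic_fields_set__`). The sketch below needs only Pydantic v2, not LangGraph. The `merge_update` helper is hypothetical: a simplified model of the behavior the example above implies (a `None` field only replaces the checkpointed value when it was explicitly set), not LangGraph's actual implementation.

```python
from typing import Any, Optional

from pydantic import BaseModel


class MyState(BaseModel):
    name: Optional[str] = None
    age: Optional[int] = None


def merge_update(checkpoint: dict[str, Any], update: BaseModel) -> dict[str, Any]:
    """Hypothetical, simplified merge (NOT LangGraph's code): copy a field from
    `update` into the checkpoint if its value is not None or it was explicitly set."""
    merged = dict(checkpoint)
    for field in type(update).model_fields:
        value = getattr(update, field)
        if value is not None or field in update.model_fields_set:
            merged[field] = value
    return merged


# A field counts as "set" only when passed to the constructor or assigned later.
print(MyState().model_fields_set)             # set()
print(MyState(name="Nick").model_fields_set)  # {'name'}
print(MyState(name=None).model_fields_set)    # {'name'}: an explicit None counts too

m = MyState()
m.__pydantic_fields_set__ = set(MyState.model_fields.keys())
print(sorted(m.model_fields_set))             # ['age', 'name']

checkpoint = {"name": "node1", "age": 10}  # state saved after the first run

# Unset None fields leave the checkpoint untouched...
print(merge_update(checkpoint, MyState()))  # {'name': 'node1', 'age': 10}
# ...while fields marked as set overwrite it, even with None.
print(merge_update(checkpoint, m))          # {'name': None, 'age': None}
```

Because passing `None` explicitly also marks a field as set, `MyState(name=None, age=None)` is probably an equivalent and less magic way to build `m`; that equivalence inside LangGraph is an assumption, not something tested here.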