It was 1 a.m. when a colleague messaged me: our smart customer service bot had amnesia again — “The user just told us the return address, and in the very next turn the bot asked ‘What would you like to return?’”
I opened the logs and started eyeballing memory variables across dozens of conversation turns. An hour later I finally nailed it: the k parameter in ConversationBufferWindowMemory was set wrong, keeping only the single most recent exchange. At that moment, I thought: do we really have to test LLM memory by chatting line by line, by hand? This can’t go on.
Breaking down the problem
Once you give an LLM-powered application a Memory component, its behavior becomes subtle. Is memory written at the right time? Does it keep or drop information as expected? In multi-turn conversations, summary, buffer, and entity memories can be stacked on top of each other, and one small misconfiguration can make the model forget what was just said.
Manual validation usually means opening a terminal, entering a few rounds of conversation, and manually inspecting memory.load_memory_variables({}). Sometimes you even have to infer the memory state from the model’s replies. This approach has fatal flaws:
- Not repeatable – The inputs, order, and timing of a manual chat are nearly impossible to reproduce exactly, so intermittent bugs are almost impossible to catch.
- Low coverage – A human tester will only cover a few happy paths. Edge cases like a full memory buffer, token truncation, or multiple memory components working together rarely get tested.
- Slow feedback – Change one memory config, restart the service, chat through several turns, and visually compare results. A single regression test run easily takes 30+ minutes.
Why not just print() the memory variable somewhere inside the code? Because in a real-world Chain the calls are often asynchronous and streamed — the intermediate printed state may be inaccurate, and you still rely on a human to read the output. This can’t be integrated into CI. We need a way to turn memory state verification into an automated, repeatable, and quantifiable testing process.
Solution design
Core idea: use LangChain’s BaseCallbackHandler to automatically capture the memory state at the end of every LLM call, then write test cases with Pytest to assert on it.
Why Pytest instead of unittest? Pytest's fixture system makes it easy to build a Chain instance equipped with memory, and parametrized tests are a natural fit for validating many memory configurations in bulk. Why not just call memory.load_memory_variables() directly? Because the memory update happens inside the Chain's own call path rather than in your code. In LangChain's legacy Chain, save_context is invoked from Chain.prep_outputs after _call returns. Reading memory from the outside, especially during an async or streamed run, can return an intermediate, inconsistent state. We need a mechanism that “peeks” at the memory right after the chain finishes executing. A custom callback hooked onto on_chain_end gets exactly that timing. on_llm_end is too early, because it fires inside _call, before the memory has been written.
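To make the timing argument concrete, here is a stdlib-only model of the call order. It assumes the legacy Chain works as follows, which you should verify against your LangChain version:

1. It runs _call, where the LLM fires on_llm_end.
2. It runs prep_outputs, which calls memory.save_context.
3. Only then does it fire on_chain_end.

ListMemory, SnapshotAtBothHooks, and ModelChain are hypothetical stand-ins, not LangChain classes.

```python
from typing import Any, Dict, List


class ListMemory:
    """Hypothetical minimal memory exposing the two public methods tests rely on."""

    def __init__(self) -> None:
        self.turns: List[str] = []

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        self.turns.append(f"Human: {inputs['input']}")
        self.turns.append(f"AI: {outputs['response']}")

    def load_memory_variables(self, _: Dict[str, Any]) -> Dict[str, Any]:
        # Return a copy so each snapshot is independent of later turns
        return {"history": list(self.turns)}


class SnapshotAtBothHooks:
    """Hypothetical handler that snapshots memory at on_llm_end and on_chain_end."""

    def __init__(self, memory: ListMemory) -> None:
        self.memory = memory
        self.at_llm_end: List[Dict[str, Any]] = []
        self.at_chain_end: List[Dict[str, Any]] = []

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        self.at_llm_end.append(self.memory.load_memory_variables({}))

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        self.at_chain_end.append(self.memory.load_memory_variables({}))


class ModelChain:
    """Simplified model of the assumed order:
    LLM call (on_llm_end) -> memory.save_context -> on_chain_end."""

    def __init__(self, memory: ListMemory, callbacks: List[Any]) -> None:
        self.memory = memory
        self.callbacks = callbacks

    def run(self, user_input: str) -> str:
        reply = f"echo: {user_input}"  # stands in for the LLM response
        for cb in self.callbacks:
            cb.on_llm_end(reply)
        self.memory.save_context({"input": user_input}, {"response": reply})
        for cb in self.callbacks:
            cb.on_chain_end({"response": reply})
        return reply


if __name__ == "__main__":
    memory = ListMemory()
    hooks = SnapshotAtBothHooks(memory)
    ModelChain(memory, [hooks]).run("My address is 42 Example Street.")
    print(len(hooks.at_llm_end[-1]["history"]), len(hooks.at_chain_end[-1]["history"]))  # 0 2
```

On the first run, the snapshot taken at on_llm_end is still empty, while the one taken at on_chain_end already holds the new exchange. That is why the capture hook belongs on on_chain_end.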
Architecturally, we agreed on a test flow. Each test case receives a pre-configured Chain (with a specific Memory) through a fixture. After calling chain.run(user_input), the test asserts on the memory variables captured by the callback. Each snapshot is simply the dict returned by load_memory_variables, which is enough for equality and containment checks.
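The flow can be sketched end to end without LangChain, an LLM, or even a Pytest import. Here build_chain plays the role of the fixture. The CASES loop is what @pytest.mark.parametrize would normally express, and plain test_ functions are still collected by Pytest.

Everything here is a hypothetical stand-in that only mimics the public save_context / load_memory_variables interface:

- BufferMemory and WindowMemory
- CaptureCallback
- FakeChain
- build_chain

The k=1 case reproduces the kind of misconfiguration described in the intro.

```python
from typing import Any, Callable, Dict, List, Tuple


class BufferMemory:
    """Hypothetical stand-in that keeps every (human, ai) exchange."""

    def __init__(self) -> None:
        self.exchanges: List[Tuple[str, str]] = []

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        self.exchanges.append((inputs["input"], outputs["response"]))

    def _window(self) -> List[Tuple[str, str]]:
        return self.exchanges

    def load_memory_variables(self, _: Dict[str, Any]) -> Dict[str, Any]:
        # Flatten the visible exchanges into a fresh list on every call
        return {"history": [line for pair in self._window() for line in pair]}


class WindowMemory(BufferMemory):
    """Hypothetical stand-in that exposes only the last k exchanges."""

    def __init__(self, k: int) -> None:
        super().__init__()
        self.k = k

    def _window(self) -> List[Tuple[str, str]]:
        return self.exchanges[-self.k:] if self.k > 0 else []


class CaptureCallback:
    """Hypothetical capturer mirroring MemoryCaptureCallback.on_chain_end."""

    def __init__(self, memory: BufferMemory) -> None:
        self.memory = memory
        self.snapshots: List[Dict[str, Any]] = []

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        self.snapshots.append(self.memory.load_memory_variables({}))


class FakeChain:
    """Hypothetical chain: canned reply, save memory, then fire on_chain_end."""

    def __init__(self, memory: BufferMemory, callbacks: List[CaptureCallback]) -> None:
        self.memory = memory
        self.callbacks = callbacks

    def run(self, user_input: str) -> str:
        reply = f"ack: {user_input}"
        self.memory.save_context({"input": user_input}, {"response": reply})
        for cb in self.callbacks:
            cb.on_chain_end({"response": reply})
        return reply


def build_chain(memory: BufferMemory) -> Tuple[FakeChain, CaptureCallback]:
    """Plays the role of the Pytest fixture: a chain plus its capturer."""
    capture = CaptureCallback(memory)
    return FakeChain(memory, [capture]), capture


SCRIPT = [
    "I want to return my order.",
    "The return address is 42 Example Street.",
    "When will the courier come?",
]

# (memory factory, should the address still be in memory after the script?)
CASES: List[Tuple[Callable[[], BufferMemory], bool]] = [
    (BufferMemory, True),
    (lambda: WindowMemory(k=2), True),
    (lambda: WindowMemory(k=1), False),  # keeps only the most recent exchange
]


def address_remembered(make_memory: Callable[[], BufferMemory]) -> bool:
    chain, capture = build_chain(make_memory())
    for message in SCRIPT:
        chain.run(message)
    assert len(capture.snapshots) == len(SCRIPT)  # one snapshot per run
    return any("42 Example Street" in line for line in capture.snapshots[-1]["history"])


def test_memory_configs() -> None:
    for make_memory, expected in CASES:
        assert address_remembered(make_memory) is expected
```

Because the script and the canned replies are fixed, every run is identical. A config change that breaks recall shows up as a failing case instead of an hour of log reading.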
Compared with other approaches:
- Using environment variables or global variables to stash memory state: pollutes the environment and breaks under concurrent tests.
- Directly accessing internal attributes like memory.chat_memory.messages: highly invasive; different Memory subclasses have different implementations, making tests fragile.
- This solution is based on the public interface, works with any Memory subclass, and can be extended freely.
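Staying on the public interface leaves one detail to handle. Depending on configuration, the value under the memory key is either a formatted string or, with return_messages=True, a list of message objects. A small normalizing helper lets assertions read the same for both.

history_text and FakeMessage are hypothetical names. The helper assumes message objects expose a .content attribute, as LangChain chat messages do.

```python
from dataclasses import dataclass
from typing import Any, Dict


def history_text(variables: Dict[str, Any], key: str = "history") -> str:
    """Flatten a load_memory_variables() result into plain text for assertions.

    Accepts either a string value or a list of message-like objects; items
    with a `.content` attribute contribute that, anything else its str().
    """
    value = variables.get(key, "")
    if isinstance(value, str):
        return value
    return "\n".join(str(getattr(item, "content", item)) for item in value)


@dataclass
class FakeMessage:
    """Hypothetical stand-in for a chat message object with `.content`."""

    content: str


if __name__ == "__main__":
    as_string = {"history": "Human: hi\nAI: hello"}
    as_messages = {"history": [FakeMessage("hi"), FakeMessage("hello")]}
    print(history_text(as_string) == "Human: hi\nAI: hello")  # True
    print(history_text(as_messages) == "hi\nhello")  # True
```

An assertion such as "42 Example Street" in history_text(snapshot) then works whether the memory returns text or messages.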
Core implementation
1. Custom callback to capture memory state
This code solves the problem of “how to get a memory snapshot after a Chain run.” We write a MemoryCaptureCallback that collects memory variables in on_chain_end and stores them in a per-instance list for test assertions. Each test builds its own handler, so concurrent tests never share state.
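```python
import copy
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import BaseMemory


class MemoryCaptureCallback(BaseCallbackHandler):
    """Capture the Memory state when a Chain finishes, for test assertions."""

    def __init__(self, memory: BaseMemory):
        super().__init__()
        self.memory = memory
        # One snapshot per chain call: the memory state after that call ended
        self.snapshots: List[Dict[str, Any]] = []

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        # Skip nested (child) chain runs; only the top-level run has no parent
        if kwargs.get("parent_run_id") is not None:
            return
        # Key point: read only after the chain has fully finished,
        # otherwise the memory may not have been updated yet
        snapshot = self.memory.load_memory_variables({})
        # Deep copy: the dict may hold the live message list, which later turns mutate
        self.snapshots.append(copy.deepcopy(snapshot))
```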
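One caveat about copying snapshots: dict(snapshot) copies only the outer mapping. In LangChain's ConversationBufferMemory, the return_messages=True path appears to hand back chat_memory.messages itself; treat this as an assumption to check for your version. If it does, every earlier snapshot silently changes as the conversation continues.

The toy below uses a hypothetical LiveListMemory to show the effect without LangChain.

```python
import copy
from typing import Any, Dict, List


class LiveListMemory:
    """Hypothetical memory whose load_memory_variables() hands back its own
    internal list, as a message buffer with return_messages=True may do."""

    def __init__(self) -> None:
        self.messages: List[str] = []

    def load_memory_variables(self, _: Dict[str, Any]) -> Dict[str, Any]:
        return {"history": self.messages}  # live reference, not a copy


memory = LiveListMemory()
memory.messages.append("Human: hi")

shallow = dict(memory.load_memory_variables({}))        # new dict, same list inside
deep = copy.deepcopy(memory.load_memory_variables({}))  # new dict and new list

memory.messages.append("AI: hello")  # a later turn mutates the live list

print(shallow["history"])  # ['Human: hi', 'AI: hello']  (snapshot silently changed)
print(deep["history"])     # ['Human: hi']
```

Deep copies cost a little more, but for test snapshots of a few dozen messages that is negligible compared with debugging aliased state.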
2. Pytest fixture to build a Chain with memory
This fixture solves the problem of “how each test case can quickly obtain a working conversation Chain.” Here we use ConversationBufferMemory as an example; you can replace it with any Memory.
```python
import pytest
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory

# MemoryCaptureCallback is the handler defined above.
# Assumes an LLM that can be mocked; in real tests, swap in a lightweight
# mock or a test-only model so runs are fast and deterministic.


@pytest.fixture
def memory_capture_chain():
    """Return a ConversationChain wired with a MemoryCaptureCallback, plus the capturer."""
    memory = ConversationBufferMemory(return_messages=True)
    capture = MemoryCaptureCallback(memory)
    llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")  # replace with a mock in tests
    # Constructor callbacks fire for this chain's own runs
    chain = ConversationChain(llm=llm, memory=memory, callbacks=[capture], verbose=False)
    return chain, capture
```