LangChain is a popular framework for developing applications powered by large language models, providing composable chains and agents for working with LLMs. As with microservices, when building production applications with LangChain, tracing and visualizing how the different components interact becomes increasingly important. AWS X-Ray, the native AWS service for monitoring and analyzing telemetry traces from Lambda functions, is the natural choice for tracing.
What is the problem?
In my previous articles "A Serverless Chatbot with LangChain & AWS Bedrock" and "Logging LangChain to AWS CloudWatch", I presented a solution for a serverless chatbot with LangChain and AWS Bedrock. The solution implements conversation history, answering in the user's language, custom context using RAG, model guardrails, and structured outputs, together with LangChain callbacks for custom detailed logging to AWS CloudWatch Logs.
There are many tools and frameworks (e.g. LangSmith, Arize Phoenix, Langfuse, etc.) based on OpenTelemetry and OpenInference that trace LLM applications and do much more (evaluate LLMs, evaluate RAG, run experiments, etc.). But I wanted to force myself to see what I could do for tracing with native AWS tools. So, I built a custom solution with the native AWS X-Ray tooling. To do that, I needed to create X-Ray trace subsegments for every action performed by any type of LangChain `Runnable`. There are two challenges associated with that.
The first one is that, by default, when enabling AWS X-Ray tracing on Lambda functions, it only traces calls to AWS services. To capture traces to external services (e.g. an HTTP request to a public REST API), the AWS X-Ray SDK for Python, for example, uses the Python `wrapt` library to "patch" specific class methods to generate a trace subsegment before executing the method. An example of that is creating a trace subsegment when the `request()` method (essentially making an HTTP call) of the `Session` class of the Python `requests` library is called, as shown here. As the X-Ray SDK covers only the most common Python libraries, I have to patch the LangChain library myself to intercept any execution of the `invoke()` or `ainvoke()` methods of every possible type of LangChain `Runnable`.
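To make the idea concrete, here is a stdlib-only sketch of what this kind of patching does conceptually. `FakeSession` and the `subsegments` list are toy stand-ins for `requests.Session` and the X-Ray recorder, not real APIs:

```python
# Toy illustration of method patching: record a "subsegment" around each call.
# FakeSession and subsegments are stand-ins, not requests or X-Ray APIs.
subsegments = []

class FakeSession:
    def request(self, method, url):
        return f"{method} {url} -> 200"

# Keep a reference to the original method before replacing it.
_orig_request = FakeSession.request

def traced_request(self, method, url):
    # The real SDK would open an X-Ray subsegment here ...
    subsegments.append(f"request: {url}")
    result = _orig_request(self, method, url)
    # ... and close it after the wrapped call returns.
    return result

# "Patch" the class method so every call goes through the tracer.
FakeSession.request = traced_request

print(FakeSession().request("GET", "https://example.com"))
print(subsegments)
```

Every caller of `FakeSession.request()` now transparently goes through the tracing code, which is exactly the effect the X-Ray SDK achieves with `wrapt` on supported libraries.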
The second challenge relates to the limitations of the X-Ray SDK in AWS Lambda Functions.
- The X-Ray SDK "is configured to automatically create a placeholder `facade` segment when it detects it is running in Lambda".
- You can't create your own trace segment in a Lambda function, only subsegments.
- The SDK creates an `AWS::Lambda::Function` subsegment, and further subsegments are attached to it.
- When using threads, every thread's trace root is the Lambda `facade` segment, not the subsegment that created the thread.
As LangChain tries to run its actions in parallel threads, or if you yourself specify some tasks to run in parallel, instead of seeing trace subsegments as:
```
AWS::Lambda::Function
└── RunnableParallel
    ├── RunnableLambda
    └── RunnableLambda
```
What you get is:
```
├── AWS::Lambda::Function
│   └── RunnableParallel
├── RunnableLambda
└── RunnableLambda
```
This is not what I would expect or want, as it does not accurately represent the inner workings of LangChain.
The Solution
To achieve my goal of correctly tracing LangChain `Runnable` interactions, I need to fix those two issues by:
- Patching the LangChain `Runnable` classes myself,
- Resetting the trace context to the parent subsegment for all LangChain `RunnableParallel` classes creating threads.
Patching LangChain
This is "easily" done by following the pattern used by the AWS X-Ray SDK for the `requests` library, as shared above. The problems are that there are many classes to cover, they will change over time, and I don't want to repeat the same code for every Runnable class. There are also two types of classes that I must patch: `Runnable` and `RunnableSerializable`.
To achieve this, I do the following:
- Create some functions to recursively list all imported child classes of `Runnable` and `RunnableSerializable` and remove duplicates.
- Loop over the resulting list and wrap the `invoke()` method of each of those classes.
```python
import wrapt
import threading
from aws_xray_sdk.core import xray_recorder
from langchain_core.runnables import Runnable, RunnableSerializable


def dedup_classes_by_origin(classes):
    """Return a set of (module, base_name) keys for unique classes."""
    seen = set()
    for cls in classes:
        module = getattr(cls, "__module__", "")
        name = getattr(cls, "__name__", str(cls))
        base_name = name.split("[")[0]
        key = (module, base_name)
        seen.add(key)
    return seen


def all_subclasses(cls):
    """Recursively find all subclasses of a class."""
    return set(cls.__subclasses__()).union(
        [s for c in cls.__subclasses__() for s in all_subclasses(c)]
    )


def patch_langchain_runnables():
    # Get all classes from langchain_core.runnables.RunnableSerializable
    unique_serializable_classes = dedup_classes_by_origin(
        all_subclasses(RunnableSerializable)
    )
    unique_runnable_classes = dedup_classes_by_origin(all_subclasses(Runnable))
    # Combine both sets to ensure we cover all RunnableSerializable and Runnable classes
    unique_classes = unique_serializable_classes | unique_runnable_classes
    for module, class_name in unique_classes:
        # Patch the invoke method of each class
        wrapt.wrap_function_wrapper(module, f"{class_name}.invoke", traced_invoke)
```
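As a quick sanity check of the recursive discovery, here is the same `all_subclasses()` helper exercised on a toy class hierarchy (the `Base`/`Child` classes are illustrative, not LangChain types):

```python
def all_subclasses(cls):
    """Recursively find all subclasses of a class."""
    return set(cls.__subclasses__()).union(
        [s for c in cls.__subclasses__() for s in all_subclasses(c)]
    )

class Base: ...
class Child(Base): ...
class GrandChild(Child): ...
class OtherChild(Base): ...

# cls.__subclasses__() alone would miss GrandChild; the recursion finds it.
print(sorted(c.__name__ for c in all_subclasses(Base)))
# ['Child', 'GrandChild', 'OtherChild']
```

Note that `__subclasses__()` only sees classes that have already been imported, which is why the patching must run after the LangChain imports.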
`patch_langchain_runnables()` then has to be called at the beginning of the Lambda function code, immediately after the imports. For example:
```python
import boto3
from aws_lambda_powertools.utilities.typing import LambdaContext
from langchain_core.runnables import (
    RunnableParallel,
    RunnableLambda,
    RunnablePassthrough,
)

patch_langchain_runnables()


def lambda_handler(event: dict, context: LambdaContext):
    print("Do something smart with LangChain here.")
```
Fixing Subsegment Parents
Now that the first challenge is solved and the `invoke()` method of all the child classes of `Runnable` and `RunnableSerializable` is wrapped by the `traced_invoke()` function, I need to define that function so it generates a subsegment before calling the original class method, while ensuring proper subsegment lineage.
For classes not using threads, a simple execution of the class method inside a subsegment would work.
```python
def traced_invoke(wrapped, instance, args, kwargs):
    # Name the subsegment after the Runnable class being invoked
    runnable_name = type(instance).__name__
    with xray_recorder.in_subsegment(runnable_name):
        result = wrapped(*args, **kwargs)
    return result
```
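The `(wrapped, instance, args, kwargs)` signature is the wrapper convention used by `wrapt`. To show how it behaves without pulling in `wrapt` or the X-Ray SDK, here is a stdlib-only stand-in that applies a wrapper with the same convention (`wrap_method` is my own helper for illustration, not a wrapt API):

```python
import functools

def wrap_method(cls, method_name, wrapper):
    """Minimal stand-in for wrapt.wrap_function_wrapper using the same
    (wrapped, instance, args, kwargs) wrapper convention."""
    original = getattr(cls, method_name)

    @functools.wraps(original)
    def patched(self, *args, **kwargs):
        bound = original.__get__(self, cls)  # the already-bound original method
        return wrapper(bound, self, args, kwargs)

    setattr(cls, method_name, patched)

calls = []

def traced(wrapped, instance, args, kwargs):
    calls.append(type(instance).__name__)  # record which class was invoked
    return wrapped(*args, **kwargs)

class Doubler:
    def invoke(self, x):
        return x * 2

wrap_method(Doubler, "invoke", traced)
print(Doubler().invoke(21))  # 42
print(calls)                 # ['Doubler']
```

The wrapper receives the already-bound original method plus the instance, which is exactly what lets `traced_invoke()` derive the subsegment name from the Runnable's class.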
But as discussed previously, for `Runnable` classes which will be invoked in threads by the `RunnableParallel` class, I need to forcibly overwrite their parent subsegment from the `facade` segment to the `RunnableParallel` subsegment. I do that with the following approach:
- Capture the current subsegment (trace entity).
- Monkey-patch (after taking a backup of it) the `threading.Thread.run()` method with my own, which sets the trace entity to the one I just captured before executing the actual `Thread.run()` method.
- Execute the original class `invoke()` method.
- Restore the original `threading.Thread.run()` method.
Thus, when the class's original `invoke()` method runs a new thread, that thread will first overwrite its trace context to the parent's subsegment.
```python
def traced_invoke(wrapped, instance, args, kwargs):
    """
    A wrapper function to trace the invocation of Runnable classes using AWS X-Ray.
    This function is used to create a subsegment in the X-Ray trace for the Runnable invocation.
    Args:
        wrapped: The original invoke method of the Runnable class.
        instance: The instance of the Runnable class being invoked.
        args: Positional arguments passed to the invoke method.
        kwargs: Keyword arguments passed to the invoke method.
    Returns:
        The result of the wrapped invoke method.
    """
    runnable_name = type(instance).__name__ if instance else wrapped.__name__
    with xray_recorder.in_subsegment(runnable_name):
        if runnable_name.startswith("RunnableParallel"):
            # Get the parent entity for the child threads
            parent_entity = xray_recorder.get_trace_entity()
            # Back up the original threading.Thread.run method
            orig_thread_run = threading.Thread.run

            # Monkey-patch threading.Thread.run to set the parent entity for new threads
            def run_with_entity(self, *a, **k):
                xray_recorder.set_trace_entity(parent_entity)
                return orig_thread_run(self, *a, **k)

            # Replace the threading.Thread.run method with our patched version
            threading.Thread.run = run_with_entity
            try:
                result = wrapped(*args, **kwargs)
            finally:
                # Once done, restore the original thread run method
                threading.Thread.run = orig_thread_run
        else:
            result = wrapped(*args, **kwargs)
    return result
```
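To see the thread-context trick in isolation, here is a self-contained toy where a value captured by the parent (standing in for the X-Ray trace entity) is injected into child threads by a temporarily patched `Thread.run()`. The `_ctx` thread-local and the string entity are illustrative stand-ins, not SDK objects:

```python
import threading

_ctx = threading.local()                       # stand-in for the X-Ray context
parent_entity = "RunnableParallel-subsegment"  # stand-in for the trace entity

orig_thread_run = threading.Thread.run

def run_with_entity(self, *a, **k):
    # Runs inside the new thread: set its context before the target executes.
    _ctx.entity = parent_entity
    return orig_thread_run(self, *a, **k)

seen = []

def worker():
    # Without the patch, this thread would see no entity at all.
    seen.append(getattr(_ctx, "entity", None))

threading.Thread.run = run_with_entity
try:
    t = threading.Thread(target=worker)
    t.start()
    t.join()
finally:
    # Always restore the original run method.
    threading.Thread.run = orig_thread_run

print(seen)  # ['RunnableParallel-subsegment']
```

Because `Thread.run()` executes inside the newly spawned thread, anything it sets (here a thread-local, in the real solution the X-Ray trace entity) is visible to the thread's target function.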
Note: following the same approach, the same can be achieved for LangChain's asynchronous `ainvoke()` method, but it is not shown here for simplicity.
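For illustration only, here is a hedged sketch of what the async counterpart could look like, using a toy in-memory recorder instead of the real `xray_recorder`. The `in_subsegment` context manager, `trace` list, and `MyRunnable` class are stand-ins of my own, not the X-Ray SDK or LangChain:

```python
import asyncio
from contextlib import contextmanager

trace = []  # toy recorder standing in for xray_recorder

@contextmanager
def in_subsegment(name):
    # Record subsegment open/close instead of sending to X-Ray.
    trace.append(f"open {name}")
    try:
        yield
    finally:
        trace.append(f"close {name}")

async def traced_ainvoke(wrapped, instance, args, kwargs):
    # Same logic as traced_invoke, but awaiting the wrapped coroutine method.
    runnable_name = type(instance).__name__ if instance else wrapped.__name__
    with in_subsegment(runnable_name):
        return await wrapped(*args, **kwargs)

class MyRunnable:
    async def ainvoke(self, x):
        return x * 2

async def main():
    r = MyRunnable()
    return await traced_ainvoke(r.ainvoke, r, (21,), {})

print(asyncio.run(main()))  # 42
print(trace)                # ['open MyRunnable', 'close MyRunnable']
```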
Patching with wrapt vs Monkey Patching
As you saw, I use two different approaches to patch different class methods:
- The `wrapt` library for the LangChain Runnable classes' `invoke()` method
- Monkey patching for the threading `Thread.run()` method
Why?
In the case of LangChain, I want to patch the `invoke()` method for all its executions and for the entire duration of the Lambda function. The `wrapt` library is designed to do exactly that.
In the case of the threads, I just want to patch the `Thread.run()` method in the temporary context of the `RunnableParallel.invoke()` method. The `wrapt` library does not provide a built-in way to temporarily patch and then restore a method at runtime within a specific block of code.
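One idiomatic stdlib way to express such a temporary patch-and-restore is a context manager. This `temporarily_patched` helper is my own sketch, not part of `wrapt`:

```python
from contextlib import contextmanager

@contextmanager
def temporarily_patched(obj, attr, replacement):
    """Replace obj.attr for the duration of the block, restoring it on exit
    even if the block raises."""
    original = getattr(obj, attr)
    setattr(obj, attr, replacement)
    try:
        yield
    finally:
        setattr(obj, attr, original)

class Thing:
    def value(self):
        return "original"

with temporarily_patched(Thing, "value", lambda self: "patched"):
    print(Thing().value())  # patched
print(Thing().value())      # original
```

The standard library's `unittest.mock.patch.object` provides the same temporary-replacement behavior, though it is primarily intended for test code.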
Be careful with monkey patching though: you are changing the behavior of methods in ways that other developers might not be aware of, and updates to the inner workings of the method you are patching might interfere with your patch. But it feels reasonable in this specific context, given the limited actions performed by the patching method.
The Results
The LangChain setup described in my first article "A Serverless Chatbot with LangChain & AWS Bedrock" results in the following AWS X-Ray trace, properly showing the different steps in order:
- The initial steps of the chain (retrieving references from the knowledge base, retrieving the conversation history, and detecting the language) run in parallel
- The prompt is generated based on all those inputs
- The LLM model is called with that prompt using Bedrock
- The tool structuring the model output based on my Pydantic definition is called
Lessons Learned
Building upon my previous articles on serverless LangChain applications and logging LangChain to AWS CloudWatch Logs, this tracing implementation with AWS X-Ray has revealed additional insights worth sharing:
- Adapting existing tools for new use cases can be challenging: While AWS X-Ray SDK wasn't designed specifically for tracing AI frameworks like LangChain, this project demonstrated the potential to extend its capabilities creatively.
- Deep understanding of both AWS Lambda and AWS X-Ray internals proved crucial: Knowing how Lambda manages threads and how the X-Ray SDK patches common libraries is key to developing an effective custom tracing solution.
- Performance implications of custom tracing solutions should be considered: While not explicitly discussed, it's important to note that any custom tracing implementation may impact the performance of the system and should be carefully monitored and optimized.
While this implementation successfully integrates LangChain tracing with AWS X-Ray as subsegments within the Lambda trace, the natural progression is to create a separate trace map specifically for LangChain. A visual representation of the interactions between LangChain components, would allow for more granular analysis and optimization of AI-driven applications.