Matthieu Lienart

Tracing LangChain with AWS X-Ray

LangChain is a popular framework for developing applications powered by large language models, providing components for working with LLMs through composable chains and agents. As with microservices, tracing and visualizing how the different components interact becomes increasingly important when building production applications with LangChain. AWS X-Ray, a native AWS service for collecting and analyzing traces from Lambda functions, is a natural choice for this.

What is the problem?

In my previous articles, "A Serverless Chatbot with LangChain & AWS Bedrock" and "Logging LangChain to AWS CloudWatch", I presented a solution for a serverless chatbot with LangChain and AWS Bedrock. The solution implements conversation history, answering in the user's language, custom context using RAG, model guardrails, and structured outputs, and it uses LangChain callbacks for detailed custom logging to AWS CloudWatch Logs.

There are many tools and frameworks (e.g. LangSmith, Arize Phoenix, Langfuse, etc.), based on OpenTelemetry and OpenInference, to trace LLM applications and to do much more (evaluate LLMs, evaluate RAG, run experiments, etc.). But I wanted to force myself to see what I could do for tracing with native AWS tools, so I built a custom solution with the native AWS X-Ray tooling. To do that, I needed to create X-Ray trace subsegments for every action performed by any type of LangChain Runnable. There are two challenges associated with that.

The first one is that, by default, enabling AWS X-Ray tracing on Lambda functions only traces calls to AWS services. To capture calls to external services (e.g. an HTTP request to a public REST API), the AWS X-Ray SDK for Python uses the Python wrapt library to "patch" specific class methods so that a trace subsegment is created before the method executes. For example, the SDK creates a trace subsegment whenever the request() method of the Session class of the Python requests library (essentially, making an HTTP call) is called. As the X-Ray SDK covers only the most common Python libraries, I have to patch the LangChain library myself to intercept every execution of the invoke() or ainvoke() methods of every possible type of LangChain Runnable.
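
To make the patching idea concrete, here is a minimal, dependency-free sketch. It is not the X-Ray SDK's code: FakeRecorder, FakeSession, and patch_method are hypothetical names invented for this illustration, and plain functools stands in for wrapt. It only shows the general shape of the technique, which is to replace a class method with a wrapper that opens a subsegment before delegating to the original.

```python
import functools
from contextlib import contextmanager


class FakeRecorder:
    """Hypothetical stand-in for xray_recorder: it only logs subsegment events."""

    def __init__(self):
        self.events = []

    @contextmanager
    def in_subsegment(self, name):
        self.events.append(f"begin {name}")
        try:
            yield
        finally:
            self.events.append(f"end {name}")


recorder = FakeRecorder()


class FakeSession:
    """Hypothetical stand-in for a third-party class such as requests.Session."""

    def request(self, method, url):
        recorder.events.append(f"{method} {url}")
        return 200


def patch_method(cls, method_name, recorder):
    """Replace cls.<method_name> with a version that runs inside a subsegment."""
    original = getattr(cls, method_name)

    @functools.wraps(original)
    def traced(self, *args, **kwargs):
        with recorder.in_subsegment(f"{cls.__name__}.{method_name}"):
            return original(self, *args, **kwargs)

    setattr(cls, method_name, traced)


patch_method(FakeSession, "request", recorder)
status = FakeSession().request("GET", "https://example.com")
```

After the patch, every call to FakeSession.request() is bracketed by a begin and an end event, without any change to the calling code.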

The second challenge relates to the limitations of the X-Ray SDK in AWS Lambda Functions.

  • The X-Ray SDK "is configured to automatically create a placeholder facade segment when it detects it is running in Lambda".
  • You can’t create your own trace segment in a Lambda Function. Only subsegments.
  • The SDK creates an AWS::Lambda::Function subsegment and further subsegments are attached to it.
  • When using threads, every thread trace root is the Lambda facade segment, not the subsegment that created the thread.

When LangChain runs actions in parallel threads, or when you explicitly define tasks to run in parallel, instead of seeing the trace subsegments as:

AWS::Lambda::Function
└── RunnableParallel
    ├── RunnableLambda
    └── RunnableLambda

What you get is:

├── AWS::Lambda::Function
│   └── RunnableParallel
├── RunnableLambda
└── RunnableLambda

This is not what I expect or want, as it does not accurately represent the inner workings of LangChain.
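
The detached subsegments can be reproduced without the SDK. The sketch below assumes that the current trace entity is kept in thread-local storage and that an empty stack falls back to the facade segment; this is an assumption made for illustration, not SDK code, and FakeContext with its methods is a hypothetical stand-in.

```python
import threading


class FakeContext:
    """Hypothetical stand-in for a trace context running in Lambda."""

    def __init__(self, facade_name):
        self.facade = facade_name        # placeholder segment provided by Lambda
        self._local = threading.local()  # each thread gets its own entity stack

    def _stack(self):
        if not hasattr(self._local, "entities"):
            self._local.entities = []
        return self._local.entities

    def push(self, name):
        self._stack().append(name)

    def pop(self):
        return self._stack().pop()

    def current(self):
        stack = self._stack()
        # An empty stack falls back to the facade segment.
        return stack[-1] if stack else self.facade


context = FakeContext("facade")
seen_by_children = []


def child_task():
    # Runs in a new thread, whose thread-local stack starts empty.
    seen_by_children.append(context.current())


context.push("RunnableParallel")
parent_in_main_thread = context.current()
threads = [threading.Thread(target=child_task) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
context.pop()
```

The main thread sees "RunnableParallel" as the current entity, while both child threads see only the facade, which matches the flattened tree shown above.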

The Solution

To trace LangChain Runnable interactions correctly, I need to fix those two issues by:

  • Patching the LangChain Runnable classes myself,
  • Resetting the trace context to the parent subsegment for all LangChain RunnableParallel classes creating threads.

Patching LangChain

This is "easily" done by following the pattern the AWS X-Ray SDK uses for the requests library, described above. The problems are that there are many classes to cover, they will change over time, and I don’t want to repeat the same code for every Runnable class. There are also two types of classes that I must patch: Runnable and RunnableSerializable.

To achieve this, I do the following:

  1. Create functions that recursively list all imported child classes of Runnable and RunnableSerializable and remove duplicates.
  2. Loop over the resulting list and wrap the invoke() method of each of those classes.

import wrapt
import threading
from aws_xray_sdk.core import xray_recorder
from langchain_core.runnables import Runnable, RunnableSerializable

def dedup_classes_by_origin(classes):
    """Return a set of (module, base_name) keys for unique classes."""
    seen = set()
    for cls in classes:
        module = getattr(cls, "__module__", "")
        name = getattr(cls, "__name__", str(cls))
        base_name = name.split("[")[0]
        key = (module, base_name)
        seen.add(key)
    return seen

def all_subclasses(cls):
    """Recursively find all subclasses of a class."""
    return set(cls.__subclasses__()).union(
        [s for c in cls.__subclasses__() for s in all_subclasses(c)]
    )

def patch_langchain_runnables():
    # Get all classes from langchain_core.runnables.RunnableSerializable
    unique_serializable_classes = dedup_classes_by_origin(
        all_subclasses(RunnableSerializable)
    )
    unique_runnable_classes = dedup_classes_by_origin(all_subclasses(Runnable))
    # Combine both lists to ensure we cover all RunnableSerializable and Runnable classes
    unique_classes = unique_serializable_classes | unique_runnable_classes
    for module, class_name in unique_classes:
        # Patch the invoke method of each class
        wrapt.wrap_function_wrapper(module, f"{class_name}.invoke", traced_invoke)
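
To see what the discovery and de-duplication steps produce, here is a small sketch on a toy hierarchy. Base, Sequence, Parallel, and RetryingParallel are hypothetical classes standing in for Runnable and its children, and the "Parallel[dict]" class only simulates a variant whose name carries a bracketed suffix (presumably what the split on "[" is meant to handle). The two helpers are condensed equivalents of the functions above, not replacements for them.

```python
def all_subclasses(cls):
    """Recursively collect every subclass of cls that has been defined so far."""
    found = set()
    for sub in cls.__subclasses__():
        found.add(sub)
        found |= all_subclasses(sub)
    return found


def dedup_classes_by_origin(classes):
    """Collapse classes to unique (module, base_name) keys, ignoring '[...]' suffixes."""
    return {(c.__module__, c.__name__.split("[")[0]) for c in classes}


# Hypothetical toy hierarchy standing in for Runnable and its children.
class Base: ...
class Sequence(Base): ...
class Parallel(Base): ...
class RetryingParallel(Parallel): ...

# Simulate a parametrized variant whose class name carries a "[...]" suffix.
ParallelOfDict = type(
    "Parallel[dict]", (Parallel,), {"__module__": Parallel.__module__}
)

discovered = all_subclasses(Base)
keys = dedup_classes_by_origin(discovered)
names = sorted(name for _, name in keys)
```

Four classes are discovered, but only three keys remain, because "Parallel[dict]" collapses onto "Parallel". Note that only classes already imported at call time can be found, which is why the call order matters.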

The patch_langchain_runnables() function must then be called at the beginning of the Lambda function code, immediately after the imports. For example:

import boto3
from aws_lambda_powertools.utilities.typing import LambdaContext
from langchain_core.runnables import (
    RunnableParallel,
    RunnableLambda,
    RunnablePassthrough,
)
patch_langchain_runnables()

def lambda_handler(event: dict, context: LambdaContext):
    print("Do something smart with LangChain here.")

Fixing Subsegment Parents

Now that the first challenge is solved and the invoke() method of every child class of Runnable and RunnableSerializable is wrapped by the traced_invoke() function, I need to define that function so that it creates a subsegment before calling the original class method, while ensuring proper subsegment lineage.

For classes not using threads, a simple execution of the class method inside a subsegment would work.

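def traced_invoke(wrapped, instance, args, kwargs):
    # Name the subsegment after the Runnable class being invoked
    runnable_name = type(instance).__name__ if instance else wrapped.__name__
    with xray_recorder.in_subsegment(runnable_name):
        result = wrapped(*args, **kwargs)
    return result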

But as discussed previously, for Runnable classes that are invoked in threads by the RunnableParallel class, I need to forcibly change their parent from the facade segment to the RunnableParallel subsegment. I do that with the following approach:

  1. Capture the new subsegment (entity)
  2. Back up the threading.Thread.run() method, then monkey patch it with my own version, which sets the trace entity to the one I just captured before executing the original Thread.run() method
  3. Execute the original class invoke() method
  4. Restore the original threading.Thread.run() method

Thus, when the original invoke() method of the class starts a new thread, that thread first sets its trace context to the parent’s subsegment.

def traced_invoke(wrapped, instance, args, kwargs):
    """
    A wrapper function to trace the invocation of Runnable classes using AWS X-Ray.
    This function is used to create a subsegment in the X-Ray trace for the Runnable invocation.

    Args:
        wrapped: The original invoke method of the Runnable class.
        instance: The instance of the Runnable class being invoked.
        args: Positional arguments passed to the invoke method.
        kwargs: Keyword arguments passed to the invoke method.

    Returns:
        The result of the wrapped invoke method.
    """
    runnable_name = type(instance).__name__ if instance else wrapped.__name__
    with xray_recorder.in_subsegment(runnable_name):
        if runnable_name.startswith("RunnableParallel"):
            # Get the parent entity for the child threads
            parent_entity = xray_recorder.get_trace_entity()
            # Back up the original threading.Thread.run method
            orig_thread_run = threading.Thread.run

            # Monkey-patch the threading.Thread.run method to set the parent entity for new threads
            def run_with_entity(self, *a, **k):
                xray_recorder.set_trace_entity(parent_entity)
                return orig_thread_run(self, *a, **k)

            # Replace the threading.Thread.run method with our patched version
            threading.Thread.run = run_with_entity
            try:
                result = wrapped(*args, **kwargs)
            finally:
                # Once done, restore the original thread run method
                threading.Thread.run = orig_thread_run
        else:
            result = wrapped(*args, **kwargs)
        return result
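
The effect of the temporary Thread.run() patch can be checked in isolation. The following sketch uses a hypothetical FakeRecorder with a thread-local current entity (an assumption standing in for xray_recorder) and a run_children() function standing in for RunnableParallel.invoke(). It compares what child threads see as their parent with and without the patch.

```python
import threading


class FakeRecorder:
    """Hypothetical stand-in for xray_recorder with a thread-local current entity."""

    def __init__(self):
        self._local = threading.local()

    def get_trace_entity(self):
        # Threads that never set an entity fall back to the facade segment.
        return getattr(self._local, "entity", "facade")

    def set_trace_entity(self, entity):
        self._local.entity = entity


recorder = FakeRecorder()


def run_children(parents_seen):
    """Stand-in for RunnableParallel.invoke(): runs two steps in threads."""
    def step():
        parents_seen.append(recorder.get_trace_entity())

    threads = [threading.Thread(target=step) for _ in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()


def run_children_with_parent(parents_seen):
    """Same work, but new threads first adopt the caller's trace entity."""
    parent_entity = recorder.get_trace_entity()
    orig_thread_run = threading.Thread.run

    def run_with_entity(self, *args, **kwargs):
        recorder.set_trace_entity(parent_entity)
        return orig_thread_run(self, *args, **kwargs)

    threading.Thread.run = run_with_entity
    try:
        run_children(parents_seen)
    finally:
        # Always restore the original method, even if the work raised.
        threading.Thread.run = orig_thread_run


recorder.set_trace_entity("RunnableParallel")
unpatched, patched = [], []
run_children(unpatched)
run_children_with_parent(patched)
```

Without the patch both threads report the facade as their parent; with it they report "RunnableParallel". Keep in mind that, while the patch is active, it applies to every thread started in the process, not only to those created by the wrapped call.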

Note: the same approach works for LangChain’s asynchronous ainvoke() method; it is not shown here for simplicity.
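
As a rough idea of what the asynchronous variant involves, here is a dependency-free sketch. It does not use the X-Ray SDK or wrapt: in_subsegment, FakeRunnable, and patch_async_method are hypothetical names, and a real implementation would also have to deal with how the SDK tracks context across asyncio tasks. The only point illustrated is that the wrapper itself must be a coroutine that awaits the original inside the subsegment, so that the subsegment covers the whole asynchronous call.

```python
import asyncio
import functools
from contextlib import contextmanager

events = []


@contextmanager
def in_subsegment(name):
    """Hypothetical stand-in for a recorder subsegment: logs begin and end."""
    events.append(f"begin {name}")
    try:
        yield
    finally:
        events.append(f"end {name}")


class FakeRunnable:
    """Hypothetical class with an asynchronous entry point."""

    async def ainvoke(self, value):
        await asyncio.sleep(0)  # yield to the event loop, like real async I/O
        events.append(f"ainvoke {value}")
        return value * 2


def patch_async_method(cls, method_name):
    """Wrap an async method so it is awaited inside a subsegment."""
    original = getattr(cls, method_name)

    @functools.wraps(original)
    async def traced(self, *args, **kwargs):
        with in_subsegment(cls.__name__):
            # Awaiting here keeps the subsegment open until the call completes.
            return await original(self, *args, **kwargs)

    setattr(cls, method_name, traced)


patch_async_method(FakeRunnable, "ainvoke")
result = asyncio.run(FakeRunnable().ainvoke(21))
```

If the wrapper returned the coroutine without awaiting it, the subsegment would close before the work had even started.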

Patching with wrapt vs Monkey Patching

As you saw, I use two different approaches to patch different class methods:

  • The wrapt library for the invoke() method of the LangChain Runnable classes
  • Monkey patching for the threading.Thread.run() method

Why?

In the case of LangChain, I want to patch the invoke() method for every execution and for the entire lifetime of the Lambda function. The wrapt library is designed to do exactly that.

In the case of the threads, I only want to patch the Thread.run() method temporarily, for the duration of the RunnableParallel.invoke() call. I use wrapt here for permanent patches; for a patch that must be applied and then restored around a specific block of code, a manual backup and restore is the simplest option.

Be careful with monkey patching though: you are changing the behavior of methods in ways that other developers might not be aware of, and changes to the inner workings of the method you are patching might interfere with your patch. In this specific context, given the limited actions performed by the patching method, it feels reasonable.
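
One way to keep such a temporary patch contained is to wrap the backup and restore steps in a context manager. The helper below, temporarily_patched, is a hypothetical name and not part of wrapt or the X-Ray SDK; Greeter and loud_greet exist only to demonstrate it. It assumes the attribute is defined directly on the owner being patched.

```python
from contextlib import contextmanager


@contextmanager
def temporarily_patched(owner, attr_name, replacement):
    """Swap owner.<attr_name> for replacement, restoring the original on exit."""
    original = getattr(owner, attr_name)
    setattr(owner, attr_name, replacement)
    try:
        yield original
    finally:
        # Restore even if the body raised an exception.
        setattr(owner, attr_name, original)


class Greeter:
    """Hypothetical class used only to demonstrate the helper."""

    def greet(self):
        return "hello"


def loud_greet(self):
    return "HELLO"


before = Greeter().greet()
with temporarily_patched(Greeter, "greet", loud_greet):
    during = Greeter().greet()
after = Greeter().greet()
```

The patch is visible only inside the with block, which makes the scope of the change explicit to anyone reading the code.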

The Results

The LangChain setup described in my first article "A Serverless Chatbot with LangChain & AWS Bedrock" results in the following AWS X-Ray trace, which properly shows the different steps in order:

  1. The initial steps of the chain, which retrieve references from the knowledge base, retrieve the conversation history, and detect the language, running in parallel
  2. The prompt being generated based on all those inputs
  3. The LLM being called with that prompt using Bedrock
  4. The tool that structures the model output based on my Pydantic definition being called

Lessons Learned

Building upon my previous articles on serverless LangChain applications and logging LangChain to AWS CloudWatch Logs, this tracing implementation with AWS X-Ray has revealed additional insights worth sharing:

  • Adapting existing tools for new use cases can be challenging: While AWS X-Ray SDK wasn't designed specifically for tracing AI frameworks like LangChain, this project demonstrated the potential to extend its capabilities creatively.
  • Deep understanding of both AWS Lambda and AWS X-Ray internals proved crucial: Knowing how Lambda manages threads and how the X-Ray SDK patches common libraries is key to developing an effective custom tracing solution.
  • Performance implications of custom tracing solutions should be considered: While not explicitly discussed, it's important to note that any custom tracing implementation may impact the performance of the system and should be carefully monitored and optimized.

While this implementation successfully integrates LangChain tracing with AWS X-Ray as subsegments within the Lambda trace, a natural next step is to create a separate trace map specifically for LangChain. A visual representation of the interactions between LangChain components would allow for more granular analysis and optimization of AI-driven applications.
