Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""OpenTelemetry Langchain instrumentation"""

import logging
import weakref
from typing import Collection

from opentelemetry import context as context_api
Expand Down Expand Up @@ -35,6 +36,12 @@

_instruments = ("langchain-core > 0.1.0", )

# Every AgentMiddleware hook name: the four sync hooks followed by their
# "a"-prefixed async counterparts. _uninstrument iterates this list to delete
# the per-instance patches.
_ALL_HOOK_NAMES = [
    "before_model",
    "after_model",
    "before_agent",
    "after_agent",
    "abefore_model",
    "aafter_model",
    "abefore_agent",
    "aafter_agent",
]

# Middleware instances whose hooks were patched at construction time. Held
# weakly so that tracking an instance does not keep it alive.
_patched_middleware_instances: weakref.WeakSet = weakref.WeakSet()


class LangchainInstrumentor(BaseInstrumentor):
"""An instrumentor for Langchain SDK."""
Expand Down Expand Up @@ -271,30 +278,54 @@ def _wrap_agent_factories(self, tracer):
logger.debug("Failed to wrap langchain.agents.create_agent: %s", e)

def _wrap_middleware_hooks(self, tracer):
    """Wrap AgentMiddleware hook methods for instrumentation.

    Uses instance-level wrapping via __init__ instead of class-level
    wrapt patching. Class-level wrapping with wrapt replaces base class
    methods with FunctionWrapper descriptors, which breaks Python identity
    checks (e.g. ``m.__class__.before_agent is not AgentMiddleware.before_agent``)
    used by LangGraph's create_agent to determine which hooks are overridden.

    Limitation: subclasses that override ``__init__`` without calling
    ``super().__init__()`` will not have their hooks instrumented. This is
    an acceptable tradeoff — such subclasses violate standard Python
    conventions, and the alternative (class-level wrapping) breaks
    LangGraph graph construction entirely.

    Patching is best-effort: if an instance cannot be tracked or patched,
    the failure is logged at debug level and construction of the
    middleware still succeeds.

    Args:
        tracer: Tracer passed to the hook wrapper factories.
    """
    sync_hooks = ["before_model", "after_model", "before_agent", "after_agent"]
    async_hooks = ["abefore_model", "aafter_model", "abefore_agent", "aafter_agent"]

    sync_wrappers = {h: create_middleware_hook_wrapper(tracer, h) for h in sync_hooks}
    async_wrappers = {h: create_async_middleware_hook_wrapper(tracer, h) for h in async_hooks}

    def _bind_sync(instance, orig, wfn):
        # Adapt the wrapt-style wrapper (wrapped, instance, args, kwargs)
        # into a plain callable that can live in the instance __dict__.
        def instrumented(*a, **kw):
            return wfn(orig, instance, a, kw)
        return instrumented

    def _bind_async(instance, orig, wfn):
        # Async twin of _bind_sync: the wrapper result is awaited.
        async def instrumented(*a, **kw):
            return await wfn(orig, instance, a, kw)
        return instrumented

    def _patch_instance(instance):
        # Register first: if the instance cannot be tracked (for example it
        # is unhashable) nothing is patched, and if a later setattr fails
        # _uninstrument can still remove whatever was already installed.
        _patched_middleware_instances.add(instance)
        for wrappers, bind in (
            (sync_wrappers, _bind_sync),
            (async_wrappers, _bind_async),
        ):
            for hook_name, wrapper_fn in wrappers.items():
                original = getattr(instance, hook_name, None)
                if original is not None:
                    setattr(instance, hook_name, bind(instance, original, wrapper_fn))

    def _middleware_init_wrapper(wrapped, instance, args, kwargs):
        # Run the real __init__ outside the guard so its own errors propagate.
        result = wrapped(*args, **kwargs)
        try:
            _patch_instance(instance)
        except Exception as e:
            # Instrumentation must never break construction of user middleware.
            logger.debug("Failed to instrument AgentMiddleware instance hooks: %s", e)
        return result

    try:
        wrap_function_wrapper(
            module="langchain.agents.middleware.types",
            name="AgentMiddleware.__init__",
            wrapper=_middleware_init_wrapper,
        )
    except Exception as e:
        logger.debug("Failed to wrap AgentMiddleware.__init__: %s", e)

def _uninstrument(self, **kwargs):
unwrap("langchain_core.callbacks", "BaseCallbackManager.__init__")
Expand All @@ -313,13 +344,18 @@ def _uninstrument(self, **kwargs):

# Unwrap AgentMiddleware hooks
if is_package_available("langchain"):
sync_hooks = ["before_model", "after_model", "before_agent", "after_agent"]
async_hooks = ["abefore_model", "aafter_model", "abefore_agent", "aafter_agent"]
for hook_name in sync_hooks + async_hooks:
try:
unwrap("langchain.agents.middleware.types", f"AgentMiddleware.{hook_name}")
except Exception:
pass
# Remove instance-level hook patches from existing instances
for instance in list(_patched_middleware_instances):
for hook_name in _ALL_HOOK_NAMES:
try:
delattr(instance, hook_name)
except AttributeError:
pass
_patched_middleware_instances.clear()
try:
unwrap("langchain.agents.middleware.types", "AgentMiddleware.__init__")
except Exception:
pass

# Unwrap LangGraph agent factories (both actual module and re-export)
if is_package_available("langgraph"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -766,17 +766,16 @@ async def abefore_model(self, state, runtime):


def test_middleware_super_call_succeeds_despite_outer_failure(instrument_legacy, span_exporter):
"""Test that wrapper records super() call as success even when outer method raises.
"""Test that wrapper records the full method call status correctly.

The instrumentation wraps AgentMiddleware.before_model (the base class method).
When a subclass calls super().before_model(), that wrapped call succeeds.
Even if the subclass's own before_model() then raises an exception, the span
for the super() call correctly records status="success".
With instance-level wrapping, the wrapper is on the instance's before_model
method. When the subclass raises an exception, the span correctly records
status="failure" since the wrapper encompasses the entire method call.
"""

class FailingMiddleware(AgentMiddleware):
def before_model(self, state, runtime):
# Call super first to trigger the wrapper, then fail
# Call super first, then fail
super().before_model(state, runtime)
raise ValueError("Intentional failure")

Expand All @@ -787,11 +786,10 @@ def before_model(self, state, runtime):
pass # Expected

spans = span_exporter.get_finished_spans()
# The wrapper is on AgentMiddleware.before_model, so look for that
middleware_spans = [s for s in spans if "before_model" in s.name]

# Should have at least one span from calling super().before_model()
# Instance-level wrapper creates a span for the full before_model call
assert len(middleware_spans) >= 1
# The span from super() call should succeed (before the ValueError is raised)
middleware_span = middleware_spans[0]
assert middleware_span.attributes[SpanAttributes.GEN_AI_TASK_STATUS] == "success"
# The method raised, so the span correctly records failure
assert middleware_span.attributes[SpanAttributes.GEN_AI_TASK_STATUS] == "failure"
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Test that middleware hook instrumentation preserves base class identity checks.

LangGraph's create_agent uses identity checks like:
m.__class__.before_agent is not AgentMiddleware.before_agent
to decide whether a middleware overrides a hook. Class-level wrapping with wrapt
breaks this by replacing base class methods with FunctionWrapper descriptors.
"""

import pytest
from langchain.agents.middleware.types import AgentMiddleware
from opentelemetry.instrumentation.langchain import LangchainInstrumentor
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter


# Sync hooks first, then their async ("a"-prefixed) counterparts.
ALL_HOOKS = [
    "before_model",
    "after_model",
    "before_agent",
    "after_agent",
    "abefore_model",
    "aafter_model",
    "abefore_agent",
    "aafter_agent",
]


class MyMiddleware(AgentMiddleware):
    """Middleware subclass whose only overridden hook is ``before_agent``."""

    def before_agent(self, state, runtime):
        # Fixed marker payload; neither argument is consulted.
        return dict(custom=True)


@pytest.fixture()
def _instrument():
    """Instrument Langchain against an in-memory exporter for one test.

    Yields:
        A ``(instrumentor, exporter)`` tuple. The exporter receives spans
        through a ``SimpleSpanProcessor`` attached to a fresh provider.
    """
    exporter = InMemorySpanExporter()
    provider = TracerProvider()
    provider.add_span_processor(SimpleSpanProcessor(exporter))

    instrumentor = LangchainInstrumentor()
    instrumentor.instrument(tracer_provider=provider)
    yield instrumentor, exporter
    # Some tests call uninstrument() themselves; skip the second call so
    # teardown does not trigger a redundant "already uninstrumented" warning.
    if instrumentor.is_instrumented_by_opentelemetry:
        instrumentor.uninstrument()


def test_base_class_identity_preserved_after_instrumentation(_instrument):
    """Inherited hooks must stay identical (``is``) to the base class ones.

    LangGraph's factory.py relies on this identity comparison to decide
    whether a middleware hook gets its own graph node.
    """
    overridden = {"before_agent"}
    for hook_name in ALL_HOOKS:
        on_base = getattr(AgentMiddleware, hook_name)
        on_subclass = getattr(MyMiddleware, hook_name)
        if hook_name not in overridden:
            # Inherited untouched, so the very same function object is expected.
            assert on_subclass is on_base, (
                f"MyMiddleware.{hook_name} should be identical to "
                f"AgentMiddleware.{hook_name} but is not — "
                f"class-level wrapping likely broke identity"
            )
            continue
        # Overridden by MyMiddleware, so it must be a different object.
        assert on_subclass is not on_base, (
            f"MyMiddleware.{hook_name} should differ from base"
        )


def test_instance_hooks_are_instrumented(_instrument):
    """Constructing a middleware must leave every hook patched on the instance."""
    middleware = MyMiddleware()
    instance_attrs = vars(middleware)
    for hook_name in ALL_HOOKS:
        assert hook_name in instance_attrs, (
            f"{hook_name} should be in instance __dict__ (instrumented)"
        )


def test_uninstrument_removes_instance_patches(_instrument):
    """Instances created while instrumented go quiet after uninstrument()."""
    instrumentor, exporter = _instrument

    middleware = MyMiddleware()
    # Construction should have installed the hook on the instance itself.
    assert "before_model" in vars(middleware), "Hook should be in instance __dict__"

    # While instrumented, a single hook call yields a single span.
    middleware.before_model({}, None)
    assert len(exporter.get_finished_spans()) == 1

    exporter.clear()

    # Tearing instrumentation down must also strip the instance patches.
    instrumentor.uninstrument()

    # The instance attribute no longer shadows the class method.
    assert "before_model" not in vars(middleware), (
        "Hook should be removed from instance __dict__ after uninstrument"
    )

    # The call now reaches the unpatched class method, so no span is recorded.
    middleware.before_model({}, None)
    assert len(exporter.get_finished_spans()) == 0, (
        "No spans should be emitted after uninstrument()"
    )