Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions netra/exporters/filtering_span_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from netra.exporters.utils import add_blocked_trace_id, get_trace_id, is_trace_id_blocked, is_trial_blocked
from netra.processors.local_filtering_span_processor import (
BLOCKED_LOCAL_PARENT_MAP,
blocked_local_parent_map_snapshot,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -113,12 +113,7 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:

# Merge with registry of locally blocked spans captured by processor to handle
# cases where children export before their blocked parent (SimpleSpanProcessor)
merged_map: Dict[Any, Any] = {}
try:
if BLOCKED_LOCAL_PARENT_MAP:
merged_map.update(BLOCKED_LOCAL_PARENT_MAP)
except Exception:
pass
merged_map: Dict[Any, Any] = blocked_local_parent_map_snapshot()
merged_map.update(blocked_parent_map)

if merged_map:
Expand Down
47 changes: 41 additions & 6 deletions netra/processors/local_filtering_span_processor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import logging
import threading
from contextlib import contextmanager
from typing import List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence

from opentelemetry import baggage
from opentelemetry import context as otel_context
Expand All @@ -15,9 +16,43 @@
# Attribute key to copy resolved local blocked patterns onto each span
_LOCAL_BLOCKED_SPANS_ATTR_KEY = "netra.local_blocked_spans"

# Registry of locally blocked spans: span_id -> parent_context
# This lets exporters reparent children reliably even when children export before parents
BLOCKED_LOCAL_PARENT_MAP: dict[object, object] = {}
# Registry of locally blocked spans: span_id -> parent_context.
# This lets exporters reparent children reliably even when children export
# before parents. All access must go through the accessor functions below
# to ensure thread-safety.
_blocked_local_parent_map: Dict[Any, Any] = {}
_blocked_local_parent_lock = threading.Lock()


def blocked_local_parent_map_put(span_id: Any, parent_context: Any) -> None:
"""Register a locally-blocked span's parent context.

Args:
span_id: The span ID of the blocked span.
parent_context: The parent ``SpanContext`` to reparent children to.
"""
with _blocked_local_parent_lock:
_blocked_local_parent_map[span_id] = parent_context


def blocked_local_parent_map_pop(span_id: Any) -> None:
"""Remove a span entry from the blocked-parent registry.

Args:
span_id: The span ID to remove.
"""
with _blocked_local_parent_lock:
_blocked_local_parent_map.pop(span_id, None)


def blocked_local_parent_map_snapshot() -> Dict[Any, Any]:
"""Return a shallow copy of the blocked-parent registry.

Returns:
A dict copy safe to iterate without holding the lock.
"""
with _blocked_local_parent_lock:
return dict(_blocked_local_parent_map)


class LocalFilteringSpanProcessor(SpanProcessor): # type: ignore[misc]
Expand Down Expand Up @@ -62,7 +97,7 @@ def on_start(self, span: trace.Span, parent_context: Optional[otel_context.Conte
parent_span.get_span_context() if hasattr(parent_span, "get_span_context") else None
)
if span_id is not None and parent_span_context is not None:
BLOCKED_LOCAL_PARENT_MAP[span_id] = parent_span_context
blocked_local_parent_map_put(span_id, parent_span_context)
# Mark on the span for visibility/debugging
try:
span.set_attribute("netra.local_blocked", True)
Expand All @@ -87,7 +122,7 @@ def on_end(self, span: trace.Span) -> None: # noqa: D401
ctx = getattr(span, "context", None)
span_id = getattr(ctx, "span_id", None) if ctx else None
if span_id is not None:
BLOCKED_LOCAL_PARENT_MAP.pop(span_id, None)
blocked_local_parent_map_pop(span_id)
except Exception:
pass
return
Expand Down
159 changes: 86 additions & 73 deletions netra/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import threading
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union
Expand All @@ -20,7 +21,14 @@ class ConversationType(str, Enum):


class SessionManager:
"""Manages session and user context for applications."""
"""Manages session and user context for applications.

All mutable class-level state is protected by ``_lock`` so that
concurrent threads (or ``asyncio.to_thread`` calls) cannot corrupt
the internal stacks and registries.
"""

_lock = threading.Lock()

# Class variable to track the current span
_current_span: Optional[trace.Span] = None
Expand All @@ -46,7 +54,8 @@ def set_current_span(cls, span: Optional[trace.Span]) -> None:
Args:
span: The current span to store
"""
cls._current_span = span
with cls._lock:
cls._current_span = span

@classmethod
def get_current_span(cls) -> Optional[trace.Span]:
Expand All @@ -56,7 +65,8 @@ def get_current_span(cls) -> Optional[trace.Span]:
Returns:
The stored current span or None if not set
"""
return cls._current_span
with cls._lock:
return cls._current_span

@classmethod
def register_span(cls, name: str, span: trace.Span) -> None:
Expand All @@ -68,13 +78,13 @@ def register_span(cls, name: str, span: trace.Span) -> None:
span: The span to register
"""
try:
stack = cls._spans_by_name.get(name)
if stack is None:
cls._spans_by_name[name] = [span]
else:
stack.append(span)
# Track globally as active
cls._active_spans.append(span)
with cls._lock:
stack = cls._spans_by_name.get(name)
if stack is None:
cls._spans_by_name[name] = [span]
else:
stack.append(span)
cls._active_spans.append(span)
except Exception:
logger.exception("Failed to register span '%s'", name)

Expand All @@ -88,21 +98,20 @@ def unregister_span(cls, name: str, span: trace.Span) -> None:
span: The span to unregister
"""
try:
stack = cls._spans_by_name.get(name)
if not stack:
return
# Remove the last matching instance (normal case)
for i in range(len(stack) - 1, -1, -1):
if stack[i] is span:
stack.pop(i)
break
if not stack:
cls._spans_by_name.pop(name, None)
# Also remove from global active list (remove last matching instance)
for i in range(len(cls._active_spans) - 1, -1, -1):
if cls._active_spans[i] is span:
cls._active_spans.pop(i)
break
with cls._lock:
stack = cls._spans_by_name.get(name)
if not stack:
return
for i in range(len(stack) - 1, -1, -1):
if stack[i] is span:
stack.pop(i)
break
if not stack:
cls._spans_by_name.pop(name, None)
for i in range(len(cls._active_spans) - 1, -1, -1):
if cls._active_spans[i] is span:
cls._active_spans.pop(i)
break
except Exception:
logger.exception("Failed to unregister span '%s'", name)

Expand Down Expand Up @@ -131,10 +140,11 @@ def get_span_by_name(cls, name: str) -> Optional[trace.Span]:
Returns:
The most recently registered span with the given name, or None if not found
"""
stack = cls._spans_by_name.get(name)
if stack:
return stack[-1]
return None
with cls._lock:
stack = cls._spans_by_name.get(name)
if stack:
return stack[-1]
return None

@classmethod
def push_entity(cls, entity_type: str, entity_name: str) -> None:
Expand All @@ -145,14 +155,15 @@ def push_entity(cls, entity_type: str, entity_name: str) -> None:
entity_type: Type of entity (workflow, task, agent, span)
entity_name: Name of the entity
"""
if entity_type == "workflow":
cls._workflow_stack.append(entity_name)
elif entity_type == "task":
cls._task_stack.append(entity_name)
elif entity_type == "agent":
cls._agent_stack.append(entity_name)
elif entity_type == "span":
cls._span_stack.append(entity_name)
with cls._lock:
if entity_type == "workflow":
cls._workflow_stack.append(entity_name)
elif entity_type == "task":
cls._task_stack.append(entity_name)
elif entity_type == "agent":
cls._agent_stack.append(entity_name)
elif entity_type == "span":
cls._span_stack.append(entity_name)

@classmethod
def pop_entity(cls, entity_type: str) -> Optional[str]:
Expand All @@ -165,15 +176,16 @@ def pop_entity(cls, entity_type: str) -> Optional[str]:
Returns:
Entity name or None if stack is empty
"""
if entity_type == "workflow" and cls._workflow_stack:
return cls._workflow_stack.pop()
elif entity_type == "task" and cls._task_stack:
return cls._task_stack.pop()
elif entity_type == "agent" and cls._agent_stack:
return cls._agent_stack.pop()
elif entity_type == "span" and cls._span_stack:
return cls._span_stack.pop()
return None
with cls._lock:
if entity_type == "workflow" and cls._workflow_stack:
return cls._workflow_stack.pop()
elif entity_type == "task" and cls._task_stack:
return cls._task_stack.pop()
elif entity_type == "agent" and cls._agent_stack:
return cls._agent_stack.pop()
elif entity_type == "span" and cls._span_stack:
return cls._span_stack.pop()
return None

@classmethod
def get_current_entity_attributes(cls) -> Dict[str, str]:
Expand All @@ -183,33 +195,31 @@ def get_current_entity_attributes(cls) -> Dict[str, str]:
Returns:
Dictionary of entity attributes to add to spans
"""
attributes = {}
with cls._lock:
attributes = {}

# Add current workflow if exists
if cls._workflow_stack:
attributes[f"{Config.LIBRARY_NAME}.workflow.name"] = cls._workflow_stack[-1]
if cls._workflow_stack:
attributes[f"{Config.LIBRARY_NAME}.workflow.name"] = cls._workflow_stack[-1]

# Add current task if exists
if cls._task_stack:
attributes[f"{Config.LIBRARY_NAME}.task.name"] = cls._task_stack[-1]
if cls._task_stack:
attributes[f"{Config.LIBRARY_NAME}.task.name"] = cls._task_stack[-1]

# Add current agent if exists
if cls._agent_stack:
attributes[f"{Config.LIBRARY_NAME}.agent.name"] = cls._agent_stack[-1]
if cls._agent_stack:
attributes[f"{Config.LIBRARY_NAME}.agent.name"] = cls._agent_stack[-1]

# Add current span if exists
if cls._span_stack:
attributes[f"{Config.LIBRARY_NAME}.span.name"] = cls._span_stack[-1]
if cls._span_stack:
attributes[f"{Config.LIBRARY_NAME}.span.name"] = cls._span_stack[-1]

return attributes
return attributes

@classmethod
def clear_entity_stacks(cls) -> None:
"""Clear all entity stacks."""
cls._workflow_stack.clear()
cls._task_stack.clear()
cls._agent_stack.clear()
cls._span_stack.clear()
with cls._lock:
cls._workflow_stack.clear()
cls._task_stack.clear()
cls._agent_stack.clear()
cls._span_stack.clear()

@classmethod
def get_stack_info(cls) -> Dict[str, List[str]]:
Expand All @@ -219,12 +229,13 @@ def get_stack_info(cls) -> Dict[str, List[str]]:
Returns:
Dictionary containing all stack contents
"""
return {
"workflows": cls._workflow_stack.copy(),
"tasks": cls._task_stack.copy(),
"agents": cls._agent_stack.copy(),
"spans": cls._span_stack.copy(),
}
with cls._lock:
return {
"workflows": cls._workflow_stack.copy(),
"tasks": cls._task_stack.copy(),
"agents": cls._agent_stack.copy(),
"spans": cls._span_stack.copy(),
}

@staticmethod
def set_session_context(
Expand Down Expand Up @@ -318,13 +329,15 @@ def add_conversation(cls, conversation_type: ConversationType, role: str, conten
span = trace.get_current_span()
if not (span and getattr(span, "is_recording", lambda: False)()):
# Fallback: use the most recent active span from SessionManager
if not cls._active_spans:
with cls._lock:
active_snapshot = list(cls._active_spans)

if not active_snapshot:
logger.warning("No active span to add conversation attribute.")
return

# Find the most recent *recording* span (the last item can be a finished span)
recording_span: Optional[trace.Span] = None
for span in reversed(cls._active_spans):
for span in reversed(active_snapshot):
try:
if span and getattr(span, "is_recording", lambda: False)():
recording_span = span
Expand Down