diff --git a/src/toolregistry/integrations/langchain/integration.py b/src/toolregistry/integrations/langchain/integration.py
index 47589e4..d310f6a 100644
--- a/src/toolregistry/integrations/langchain/integration.py
+++ b/src/toolregistry/integrations/langchain/integration.py
@@ -106,12 +106,20 @@ def from_langchain_tool(
         input_schema = tool.input_schema.model_json_schema()
         del input_schema["description"]  # del it for the sake of consistency
 
+        # Use the fully-qualified class name as source_detail.
+        tool_cls = type(tool)
+        source_detail = f"{tool_cls.__module__}.{tool_cls.__qualname__}"
+
         tool_instance = cls(
             name=wrapper.name,
             description=wrapper.tool.description,
             parameters=input_schema,
             callable=wrapper,
-            metadata=ToolMetadata(is_async=False),
+            metadata=ToolMetadata(
+                is_async=False,
+                source="langchain",
+                source_detail=source_detail,
+            ),
         )
 
         if namespace:
diff --git a/src/toolregistry/integrations/mcp/integration.py b/src/toolregistry/integrations/mcp/integration.py
index fbf3fff..0ce20d9 100644
--- a/src/toolregistry/integrations/mcp/integration.py
+++ b/src/toolregistry/integrations/mcp/integration.py
@@ -239,12 +239,26 @@ def from_tool_json(
             ),
         )
 
+        # Build a human-readable source_detail from the transport config.
+        transport = connection.transport
+        if isinstance(transport, dict):
+            cmd = transport.get("command", "")
+            # Coerce each arg: configs may hold non-str values or "args": None.
+            args = " ".join(str(arg) for arg in transport.get("args") or [])
+            source_detail = f"stdio:{cmd} {args}".strip()
+        else:
+            source_detail = str(transport)
+
         tool = cls(
             name=normalize_tool_name(name),
             description=description,
             parameters=input_schema,
             callable=wrapper,
-            metadata=ToolMetadata(is_async=False),
+            metadata=ToolMetadata(
+                is_async=False,
+                source="mcp",
+                source_detail=source_detail,
+            ),
         )
 
         if namespace:
diff --git a/src/toolregistry/integrations/openapi/integration.py b/src/toolregistry/integrations/openapi/integration.py
index a37d07f..302af73 100644
--- a/src/toolregistry/integrations/openapi/integration.py
+++ b/src/toolregistry/integrations/openapi/integration.py
@@ -175,12 +175,19 @@ def from_openapi_spec(
             persistent=persistent,
         )
 
+        # Join base URL and endpoint path without doubling the "/" separator.
+        source_detail = f"{str(client_config.base_url).rstrip('/')}{path}"
+
         tool = cls(
             name=func_name,
             description=description,
             parameters=parameters,
             callable=wrapper,
-            metadata=ToolMetadata(is_async=False),
+            metadata=ToolMetadata(
+                is_async=False,
+                source="openapi",
+                source_detail=source_detail,
+            ),
         )
 
         if namespace:
diff --git a/src/toolregistry/tool.py b/src/toolregistry/tool.py
index 9704e03..0925e73 100644
--- a/src/toolregistry/tool.py
+++ b/src/toolregistry/tool.py
@@ -57,6 +57,10 @@ class ToolMetadata(BaseModel):
             ``ToolRegistry.execute_tool_calls()``. None means no limit.
         tags: Predefined tags from ToolTag enum.
         custom_tags: User-defined free-form string tags.
+        source: Origin of the tool (e.g. ``"native"``, ``"mcp"``,
+            ``"openapi"``, ``"langchain"``).
+        source_detail: Extra detail about the tool's origin (e.g. a
+            transport URI, spec URL, or class name).
         extra: Arbitrary key-value pairs for application-specific use.
     """
 
@@ -69,6 +73,21 @@ class ToolMetadata(BaseModel):
     tags: set[ToolTag] = Field(default_factory=set)
     custom_tags: set[str] = Field(default_factory=set)
 
+    source: str = "native"
+    """Origin of the tool.
+
+    Indicates which integration registered the tool. Standard values:
+    ``"native"``, ``"mcp"``, ``"openapi"``, ``"langchain"``.
+    """
+
+    source_detail: str = ""
+    """Extra detail about the tool's origin.
+
+    Free-form string providing additional context about where the tool
+    came from, e.g. a transport URI for MCP tools, a spec URL for
+    OpenAPI tools, or a class name for LangChain tools.
+    """
+
     extra: dict[str, Any] = Field(default_factory=dict)
 
     defer: bool = False
diff --git a/tests/test_tool_source.py b/tests/test_tool_source.py
new file mode 100644
index 0000000..15254f2
--- /dev/null
+++ b/tests/test_tool_source.py
@@ -0,0 +1,153 @@
+"""Tests for ToolMetadata.source and ToolMetadata.source_detail fields."""
+
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+
+from toolregistry import Tool, ToolMetadata
+
+
+# ---------------------------------------------------------------------------
+# Native tools
+# ---------------------------------------------------------------------------
+
+
+def _dummy(x: int) -> int:
+    """Return x."""
+    return x
+
+
+class TestNativeToolSource:
+    """Native tools should default to source='native' with empty detail."""
+
+    def test_default_source(self):
+        m = ToolMetadata()
+        assert m.source == "native"
+        assert m.source_detail == ""
+
+    def test_from_function_default_source(self):
+        tool = Tool.from_function(_dummy)
+        assert tool.metadata.source == "native"
+        assert tool.metadata.source_detail == ""
+
+    def test_explicit_source_override(self):
+        m = ToolMetadata(source="custom", source_detail="some detail")
+        assert m.source == "custom"
+        assert m.source_detail == "some detail"
+
+
+# ---------------------------------------------------------------------------
+# MCP tools
+# ---------------------------------------------------------------------------
+
+
+class TestMCPToolSource:
+    """MCPTool.from_tool_json should set source='mcp'."""
+
+    def test_mcp_source_with_url_transport(self):
+        mcp_types = pytest.importorskip("mcp.types")
+
+        from toolregistry.integrations.mcp.integration import MCPTool
+
+        tool_spec = mcp_types.Tool(
+            name="echo",
+            description="Echo input",
+            inputSchema={"type": "object", "properties": {"msg": {"type": "string"}}},
+        )
+
+        connection = MagicMock()
+        connection.transport = "http://localhost:8080/sse"
+
+        tool = MCPTool.from_tool_json(tool_spec, connection=connection)
+        assert tool.metadata.source == "mcp"
+        assert tool.metadata.source_detail == "http://localhost:8080/sse"
+
+    def test_mcp_source_with_stdio_transport(self):
+        mcp_types = pytest.importorskip("mcp.types")
+
+        from toolregistry.integrations.mcp.integration import MCPTool
+
+        tool_spec = mcp_types.Tool(
+            name="greet",
+            description="Greet user",
+            inputSchema={"type": "object", "properties": {"name": {"type": "string"}}},
+        )
+
+        connection = MagicMock()
+        connection.transport = {"command": "uvx", "args": ["my-server"]}
+
+        tool = MCPTool.from_tool_json(tool_spec, connection=connection)
+        assert tool.metadata.source == "mcp"
+        assert tool.metadata.source_detail == "stdio:uvx my-server"
+
+
+# ---------------------------------------------------------------------------
+# OpenAPI tools
+# ---------------------------------------------------------------------------
+
+
+class TestOpenAPIToolSource:
+    """OpenAPITool.from_openapi_spec should set source='openapi'."""
+
+    def test_openapi_source(self):
+        pytest.importorskip("jsonref")
+
+        from toolregistry.integrations.openapi.integration import OpenAPITool
+        from toolregistry.utils import HttpxClientConfig
+
+        client_config = HttpxClientConfig(base_url="https://api.example.com")
+        spec: dict[str, Any] = {
+            "operationId": "listItems",
+            "summary": "List items",
+            "parameters": [],
+        }
+
+        tool = OpenAPITool.from_openapi_spec(
+            client_config=client_config,
+            path="/items",
+            method="get",
+            spec=spec,
+        )
+        assert tool.metadata.source == "openapi"
+        assert tool.metadata.source_detail == "https://api.example.com/items"
+
+
+# ---------------------------------------------------------------------------
+# LangChain tools
+# ---------------------------------------------------------------------------
+
+
+class TestLangChainToolSource:
+    """LangChainTool.from_langchain_tool should set source='langchain'."""
+
+    def test_langchain_source(self):
+        pytest.importorskip("langchain_core")
+
+        from langchain_core.tools import BaseTool as LCBaseTool
+        from pydantic import BaseModel, Field
+
+        from toolregistry.integrations.langchain.integration import LangChainTool
+
+        class AddInput(BaseModel):
+            """Input for adding two numbers."""
+
+            a: int = Field(description="First number")
+            b: int = Field(description="Second number")
+
+        class MockAddTool(LCBaseTool):
+            name: str = "add_numbers"
+            description: str = "Add two numbers together"
+            args_schema: type[BaseModel] = AddInput
+
+            def _run(self, a: int, b: int) -> int:
+                return a + b
+
+            async def _arun(self, a: int, b: int) -> int:
+                return a + b
+
+        lc_tool = MockAddTool()
+        tool = LangChainTool.from_langchain_tool(lc_tool)
+        assert tool.metadata.source == "langchain"
+        # source_detail should contain the class name
+        assert "MockAddTool" in tool.metadata.source_detail