From acad2d8ba6914106c04ad86868a9bf8a65e3f120 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Wed, 15 Apr 2026 11:07:42 +0800 Subject: [PATCH 1/8] feat: add persistent custom operator registration functionality --- data_juicer/ops/__init__.py | 45 ++-- data_juicer/tools/op_search.py | 214 +++++++++++++-- data_juicer/utils/custom_op.py | 467 +++++++++++++++++++++++++++++++++ data_juicer/utils/registry.py | 11 + docs/DeveloperGuide.md | 47 +++- docs/DeveloperGuide_ZH.md | 47 +++- tests/tools/test_op_search.py | 80 ++++++ tests/utils/test_custom_op.py | 429 ++++++++++++++++++++++++++++++ 8 files changed, 1299 insertions(+), 41 deletions(-) create mode 100644 data_juicer/utils/custom_op.py create mode 100644 tests/utils/test_custom_op.py diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py index 15b2b7bda8..b151807c0f 100644 --- a/data_juicer/ops/__init__.py +++ b/data_juicer/ops/__init__.py @@ -14,7 +14,18 @@ def timing_context(description): # yapf: disable with timing_context('Importing operator modules'): + # 1. Built-in operators (registered via @OPERATORS.register_module decorators + # that fire as each sub-package is imported) + # 2. Persistent custom operators (loaded from ~/.data_juicer/op_registry.json; + # no-op when the registry file does not exist) + from data_juicer.utils.custom_op import ( + load_persistent_custom_ops as _load_persistent, + ) + from . import aggregator, deduplicator, filter, grouper, mapper, pipeline, selector + _load_persistent() + del _load_persistent + from .base_op import ( ATTRIBUTION_FILTERS, NON_STATS_FILTERS, @@ -39,21 +50,21 @@ def timing_context(description): ) __all__ = [ - 'load_ops', - 'Filter', - 'Mapper', - 'Deduplicator', - 'Selector', - 'Grouper', - 'Aggregator', - 'UNFORKABLE', - 'NON_STATS_FILTERS', - 'OPERATORS', - 'TAGGING_OPS', - 'Pipeline', - 'OPEnvSpec', - 'op_requirements_to_op_env_spec', - 'OPEnvManager', - 'analyze_lazy_loaded_requirements', - 'analyze_lazy_loaded_requirements_for_code_file', + "load_ops", + "Filter", + "Mapper", + "Deduplicator", + "Selector", + "Grouper", + "Aggregator", + "UNFORKABLE", + "NON_STATS_FILTERS", + "OPERATORS", + "TAGGING_OPS", + "Pipeline", + "OPEnvSpec", + "op_requirements_to_op_env_spec", + "OPEnvManager", + "analyze_lazy_loaded_requirements", + "analyze_lazy_loaded_requirements_for_code_file", ] diff --git a/data_juicer/tools/op_search.py b/data_juicer/tools/op_search.py index e7a1f79a35..32cda3f0d9 100644 --- a/data_juicer/tools/op_search.py +++ b/data_juicer/tools/op_search.py @@ -159,22 +159,42 @@ class OPRecord: def __init__(self, name: str, op_cls: type, op_type: Optional[str] = None): self.name = name - self.type = op_type or op_cls.__module__.split(".")[2].lower() + + # --- module path: + # handling for custom ops --- + if op_type: + self.type = op_type + else: + module_parts = op_cls.__module__.split(".") + if len(module_parts) >= 3: + self.type = module_parts[2].lower() + else: + self.type = self._search_mro_for_type(op_cls) if self.type not in op_type_list: self.type = self._search_mro_for_type(op_cls) + self.desc = op_cls.__doc__ or "" self.tags = analyze_tag_from_cls(op_cls, name) self.sig = inspect.signature(op_cls.__init__) self.param_desc = extract_param_docstring(op_cls.__init__.__doc__ or "") self.param_desc_map = self._parse_param_desc() - self.source_path = str(get_source_path(op_cls)) - self.test_path = None - test_path = f"tests/ops/{self.type}/test_{self.name}.py" - if not (PROJECT_ROOT / test_path).exists(): - test_path = find_test_by_searching_content(PROJECT_ROOT / "tests", op_cls.__name__ + "Test") or test_path + # --- source path: handling for custom ops --- + try: + self.source_path = str(get_source_path(op_cls)) + except ValueError: + self.source_path = str(Path(inspect.getfile(op_cls))) - self.test_path = str(test_path) + # --- test path: handling for custom ops --- + try: + test_path = f"tests/ops/{self.type}/test_{self.name}.py" + if not (PROJECT_ROOT / test_path).exists(): + test_path = ( + find_test_by_searching_content(PROJECT_ROOT / "tests", op_cls.__name__ + "Test") or test_path + ) + self.test_path = str(test_path) + except Exception: + self.test_path = None def __getitem__(self, item): try: @@ -441,26 +461,176 @@ def records_map(self): return self.all_ops -def main(query, tags, op_type): +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def _build_parser(): + import argparse + + parser = argparse.ArgumentParser( + prog="python -m data_juicer.tools.op_search", + description="Data-Juicer Operator Search & Query Tool", + ) + sub = parser.add_subparsers(dest="command", help="Available commands") + + # --- list --- + sub.add_parser( + "list", + help="List all operators (built-in + custom)", + ) + + # --- info --- + p_info = sub.add_parser( + "info", + help="Show detailed information about an operator", + ) + p_info.add_argument("name", help="Operator name") + + # --- search --- + p_search = sub.add_parser( + "search", + help="Search operators by keyword, regex, or tags", + ) + p_search.add_argument( + "query", + nargs="?", + default=None, + help="Search query (natural language or regex pattern)", + ) + p_search.add_argument( + "--mode", + choices=["bm25", "regex"], + default="bm25", + help="Search mode (default: bm25)", + ) + p_search.add_argument( + "--tags", + nargs="+", + default=None, + help="Filter by tags (e.g., gpu, cpu, text, image)", + ) + p_search.add_argument( + "--type", + dest="op_type", + default=None, + help="Filter by operator type (e.g., mapper, filter)", + ) + p_search.add_argument( + "--top-k", + type=int, + default=10, + help="Maximum number of results (default: 10)", + ) + + return parser + + +def _cmd_list(args) -> int: + """List all operators (built-in + custom).""" + from data_juicer.utils.custom_op import list_registered + + custom_info = list_registered() + custom_names = set(custom_info.get("custom_operators", {}).keys()) + all_names = sorted(OPERATORS.modules.keys()) + + print(f"Total operators: {len(all_names)}") + print(f" Built-in: {len(all_names) - len(custom_names)}") + print(f" Custom: {len(custom_names)}") + print() + for name in all_names: + marker = " [custom]" if name in custom_names else "" + print(f" {name}{marker}") + return 0 + + +def _cmd_info(args) -> int: + """Show detailed information about an operator.""" + import sys + + op_cls = OPERATORS.modules.get(args.name) + if op_cls is None: + print(f"Operator '{args.name}' not found.", file=sys.stderr) + return 1 + + record = OPRecord(name=args.name, op_cls=op_cls) + info = record.to_dict() + + print(f"Name: {info['name']}") + print(f"Type: {info['type']}") + print(f"Tags: {', '.join(info['tags']) if info['tags'] else 'none'}") + print(f"Source: {info['source_path']}") + print(f"Test: {info['test_path'] or 'none'}") + print(f"Signature: {info['sig']}") + print() + if info["desc"]: + print("Description:") + print(f" {info['desc'].strip()}") + print() + if info["param_desc_map"]: + print("Parameters:") + for pname, pdesc in info["param_desc_map"].items(): + print(f" {pname}: {pdesc}") + + return 0 + + +def _cmd_search(args) -> int: + """Search operators by keyword, regex, or tags.""" searcher = OPSearcher(include_formatter=True) - results = searcher.search_by_bm25(query=query, tags=tags, op_type=op_type) + query = args.query + tags = args.tags + op_type = args.op_type - print(f"\nFound {len(results)} operators:") + if args.mode == "regex" and query: + results = searcher.search_by_regex(query=query, tags=tags, op_type=op_type) + elif query: + results = searcher.search_by_bm25(query=query, tags=tags, op_type=op_type, top_k=args.top_k) + else: + results = searcher.search(tags=tags, op_type=op_type) + + print(f"Found {len(results)} operator(s):") for op in results: - print(f"\n[{op['type'].upper()}] {op['name']}") - print(f"Tags: {', '.join(op['tags'])}") - print(f"Description: {op['desc']}") - print(f"Parameters: {op['param_desc']}") - print(f"Parameter Descriptions: {op['param_desc_map']}") - print(f"Signature: {op['sig']}") - print("-" * 50) + print(f"\n [{op['type'].upper()}] {op['name']}") + print(f" Tags: {', '.join(op['tags'])}") + desc = (op.get("desc") or "").strip() + if desc: + first_line = desc.split("\n")[0].strip() + if len(first_line) > 80: + first_line = first_line[:77] + "..." + print(f" Desc: {first_line}") + + return 0 - print(searcher.records_map["nlpaug_en_mapper"]["source_path"]) - print(searcher.records_map["nlpaug_en_mapper"].test_path) + +_COMMAND_MAP = { + "list": _cmd_list, + "info": _cmd_info, + "search": _cmd_search, +} + + +def main(argv=None) -> int: + """CLI entry point for operator search & query.""" + + parser = _build_parser() + args = parser.parse_args(argv) + + if not args.command: + parser.print_help() + return 1 + + handler = _COMMAND_MAP.get(args.command) + if handler is None: + parser.print_help() + return 1 + + return handler(args) if __name__ == "__main__": - tags = [] - op_type = "formatter" - main(query="json", tags=tags, op_type=op_type) + import sys + + sys.exit(main()) diff --git a/data_juicer/utils/custom_op.py b/data_juicer/utils/custom_op.py new file mode 100644 index 0000000000..3d834393db --- /dev/null +++ b/data_juicer/utils/custom_op.py @@ -0,0 +1,467 @@ +""" +Custom Operator Management +=========================== + +Manages a persistent JSON registry at ``~/.data_juicer/op_registry.json`` +so that custom operators survive across processes. + +The registry file stores source-file paths keyed by operator name. +On startup, ``load_persistent_custom_ops()`` replays those registrations +into the in-process ``OPERATORS`` registry, automatically cleaning up +entries whose source files no longer exist. + +Also provides a CLI for managing custom operators:: + + python -m data_juicer.utils.custom_op list + python -m data_juicer.utils.custom_op register /path/to/my_mapper.py + python -m data_juicer.utils.custom_op unregister my_mapper + python -m data_juicer.utils.custom_op reset +""" + +import argparse +import inspect +import json +import os +import sys +import tempfile +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + +from loguru import logger + +# --------------------------------------------------------------------------- +# Path management +# --------------------------------------------------------------------------- + + +def get_registry_path() -> Path: + """Return the path to the persistent op registry JSON file. + + Defaults to ``~/.data_juicer/op_registry.json``. + Override with the ``DJ_OP_REGISTRY`` environment variable. + """ + override = os.environ.get("DJ_OP_REGISTRY") + if override: + return Path(override) + return Path.home() / ".data_juicer" / "op_registry.json" + + +# --------------------------------------------------------------------------- +# Low-level read / write helpers +# --------------------------------------------------------------------------- + +_EMPTY_REGISTRY: Dict = {"version": 1, "custom_operators": {}} + + +def _read_registry() -> dict: + """Read the JSON registry. Returns the empty structure when the file + does not exist or is malformed.""" + path = get_registry_path() + if not path.exists(): + return json.loads(json.dumps(_EMPTY_REGISTRY)) # deep copy + try: + with open(path, "r", encoding="utf-8") as fh: + data = json.load(fh) + if not isinstance(data, dict) or "custom_operators" not in data: + logger.warning(f"Malformed op registry at {path}, resetting to empty.") + return json.loads(json.dumps(_EMPTY_REGISTRY)) + return data + except (json.JSONDecodeError, OSError) as exc: + logger.warning(f"Failed to read op registry at {path}: {exc}") + return json.loads(json.dumps(_EMPTY_REGISTRY)) + + +def _write_registry(data: dict) -> None: + """Atomically write *data* to the registry JSON file. + + Uses write-to-tmp + ``os.replace`` to avoid partial writes. + """ + path = get_registry_path() + path.parent.mkdir(parents=True, exist_ok=True) + + # Write to a temp file in the same directory, then atomically replace. + fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp", prefix=".op_registry_") + try: + with os.fdopen(fd, "w", encoding="utf-8") as fh: + json.dump(data, fh, indent=2, ensure_ascii=False) + fh.write("\n") + os.replace(tmp_path, str(path)) + except BaseException: + # Clean up the temp file on failure. + try: + os.unlink(tmp_path) + except OSError: + pass + raise + + +# --------------------------------------------------------------------------- +# Custom Op management +# --------------------------------------------------------------------------- + + +def register_persistent(paths: List[str]) -> dict: + """Register custom operators to the persistent registry **and** the + current-process ``OPERATORS``. + + *paths* is a list of file / directory paths (same semantics as + ``load_custom_operators``). + + Returns ``{"registered": [...], "warnings": [...]}``. + """ + from data_juicer.config.config import load_custom_operators + from data_juicer.ops.base_op import OPERATORS + + # Snapshot operator names before loading. + before = set(OPERATORS.modules.keys()) + + warnings: List[str] = [] + valid_paths: List[str] = [] + for p in paths: + abs_p = os.path.abspath(p) + if not os.path.exists(abs_p): + warnings.append(f"Path does not exist: {abs_p}") + continue + valid_paths.append(abs_p) + + if valid_paths: + load_custom_operators(valid_paths) + + # Diff to find newly registered names. + after = set(OPERATORS.modules.keys()) + new_names = sorted(after - before) + + # Build a mapping from op name -> source path. + name_to_path: Dict[str, str] = {} + for p in valid_paths: + abs_p = os.path.abspath(p) + if os.path.isfile(abs_p): + # Find which new names came from this file. + for n in new_names: + op_cls = OPERATORS.modules.get(n) + if op_cls is not None: + try: + cls_file = os.path.abspath(inspect.getfile(op_cls)) + if cls_file == abs_p: + name_to_path[n] = abs_p + except (TypeError, OSError): + pass + elif os.path.isdir(abs_p): + # Directory – inspect each new op to find its file. + for n in new_names: + if n in name_to_path: + continue + op_cls = OPERATORS.modules.get(n) + if op_cls is not None: + try: + cls_file = os.path.abspath(inspect.getfile(op_cls)) + if cls_file.startswith(abs_p): + name_to_path[n] = cls_file + except (TypeError, OSError): + pass + + # Fallback: for any new name not yet mapped, try inspect. + for n in new_names: + if n not in name_to_path: + op_cls = OPERATORS.modules.get(n) + if op_cls is not None: + try: + name_to_path[n] = os.path.abspath(inspect.getfile(op_cls)) + except (TypeError, OSError): + warnings.append(f"Could not determine source path for '{n}'") + + # Persist. + registry = _read_registry() + now = datetime.now().isoformat(timespec="seconds") + for n in new_names: + src = name_to_path.get(n) + if src: + registry["custom_operators"][n] = { + "source_path": src, + "registered_at": now, + } + + _write_registry(registry) + + return {"registered": new_names, "warnings": warnings} + + +def unregister_ops(names: List[str]) -> dict: + """Remove the given custom operators from the persistent registry and + the current-process ``OPERATORS``. + + Returns ``{"removed": [...], "not_found": [...]}``. + """ + from data_juicer.ops.base_op import OPERATORS + + registry = _read_registry() + removed: List[str] = [] + not_found: List[str] = [] + + for name in names: + if name in registry["custom_operators"]: + del registry["custom_operators"][name] + removed.append(name) + else: + not_found.append(name) + + # Also remove from in-process registry. + OPERATORS.unregister_module(name) + + # Remove from sys.modules if present. + # The module name is typically the stem of the source file. + to_remove = [k for k in sys.modules if k == name or k.endswith(f".{name}")] + for k in to_remove: + del sys.modules[k] + + _write_registry(registry) + return {"removed": removed, "not_found": not_found} + + +def reset_registry() -> dict: + """Clear **all** custom operators from the persistent registry and the + current-process ``OPERATORS``. + + Returns ``{"removed": [...]}``. + """ + from data_juicer.ops.base_op import OPERATORS + + registry = _read_registry() + removed = sorted(registry["custom_operators"].keys()) + + for name in removed: + OPERATORS.unregister_module(name) + to_remove = [k for k in sys.modules if k == name or k.endswith(f".{name}")] + for k in to_remove: + del sys.modules[k] + + registry["custom_operators"] = {} + _write_registry(registry) + + return {"removed": removed} + + +# --------------------------------------------------------------------------- +# Query +# --------------------------------------------------------------------------- + + +def list_registered() -> dict: + """Return the contents of the persistent registry. + + Returns ``{"custom_operators": {...}}``. + """ + registry = _read_registry() + return {"custom_operators": registry.get("custom_operators", {})} + + +# --------------------------------------------------------------------------- +# Startup loading +# --------------------------------------------------------------------------- + + +def load_persistent_custom_ops() -> dict: + """Load all custom operators from the persistent registry into the + current-process ``OPERATORS``. + + Entries whose source files no longer exist are automatically removed + from the registry and a warning is logged. + + Returns ``{"loaded": [...], "cleaned": [...], "warnings": [...]}``. + """ + from data_juicer.config.config import load_custom_operators + + registry = _read_registry() + custom_ops = registry.get("custom_operators", {}) + + if not custom_ops: + return {"loaded": [], "cleaned": [], "warnings": []} + + loaded: List[str] = [] + cleaned: List[str] = [] + warnings: List[str] = [] + + # Validate paths first. + valid_entries: Dict[str, dict] = {} + for name, meta in list(custom_ops.items()): + src = meta.get("source_path", "") + if not src or not os.path.exists(src): + logger.warning(f"Custom op '{name}' source not found at '{src}', " f"removing from registry.") + cleaned.append(name) + warnings.append(f"Cleaned stale entry '{name}': source '{src}' not found.") + else: + valid_entries[name] = meta + + # If we cleaned anything, persist the updated registry. + if cleaned: + registry["custom_operators"] = valid_entries + _write_registry(registry) + + # Load valid entries. + # Group by source_path to avoid loading the same file twice. + paths_to_load: List[str] = [] + seen_paths: set = set() + for name, meta in valid_entries.items(): + src = meta["source_path"] + if src not in seen_paths: + paths_to_load.append(src) + seen_paths.add(src) + + if paths_to_load: + try: + load_custom_operators(paths_to_load) + loaded = sorted(valid_entries.keys()) + except Exception as exc: + msg = f"Failed to load persistent custom ops: {exc}" + logger.error(msg) + warnings.append(msg) + + return {"loaded": loaded, "cleaned": cleaned, "warnings": warnings} + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="python -m data_juicer.utils.custom_op", + description="Data-Juicer Custom Operator Management Tool " "(does not affect built-in operators)", + ) + sub = parser.add_subparsers(dest="command", help="Available commands") + + # --- list --- + sub.add_parser( + "list", + help="List registered custom operators", + ) + + # --- register --- + p_reg = sub.add_parser( + "register", + help="Register custom operator(s) persistently", + ) + p_reg.add_argument( + "paths", + nargs="+", + help="Path(s) to custom operator file(s) or directory(ies)", + ) + + # --- unregister --- + p_unreg = sub.add_parser( + "unregister", + help="Unregister custom operator(s) from the persistent registry", + ) + p_unreg.add_argument( + "names", + nargs="+", + help="Name(s) of custom operators to remove", + ) + + # --- reset --- + sub.add_parser( + "reset", + help="Clear all custom operators from the persistent registry", + ) + + return parser + + +def _cmd_list(args) -> int: + """List registered custom operators.""" + result = list_registered() + custom_ops = result.get("custom_operators", {}) + if not custom_ops: + print("No custom operators registered.") + return 0 + print(f"Custom operators ({len(custom_ops)}):") + for name, meta in sorted(custom_ops.items()): + src = meta.get("source_path", "?") + reg_at = meta.get("registered_at", "?") + print(f" {name}") + print(f" source: {src}") + print(f" registered_at: {reg_at}") + return 0 + + +def _cmd_register(args) -> int: + """Register custom operator(s) persistently.""" + result = register_persistent(args.paths) + registered = result.get("registered", []) + warnings = result.get("warnings", []) + + if registered: + print(f"Registered {len(registered)} operator(s):") + for name in registered: + print(f" + {name}") + else: + print("No new operators registered.") + + for warning in warnings: + print(f" WARNING: {warning}", file=sys.stderr) + + return 0 + + +def _cmd_unregister(args) -> int: + """Unregister custom operator(s) from the persistent registry.""" + result = unregister_ops(args.names) + removed = result.get("removed", []) + not_found = result.get("not_found", []) + + if removed: + print(f"Removed {len(removed)} operator(s):") + for name in removed: + print(f" - {name}") + + if not_found: + print(f"Not found ({len(not_found)}):") + for name in not_found: + print(f" ? {name}") + + return 0 + + +def _cmd_reset(args) -> int: + """Clear all custom operators from the persistent registry.""" + result = reset_registry() + removed = result.get("removed", []) + + if removed: + print(f"Removed {len(removed)} custom operator(s):") + for name in removed: + print(f" - {name}") + else: + print("Registry was already empty.") + return 0 + + +_COMMAND_MAP = { + "list": _cmd_list, + "register": _cmd_register, + "unregister": _cmd_unregister, + "reset": _cmd_reset, +} + + +def main(argv: Optional[List[str]] = None) -> int: + """CLI entry point for custom operator management.""" + parser = _build_parser() + args = parser.parse_args(argv) + + if not args.command: + parser.print_help() + return 1 + + handler = _COMMAND_MAP.get(args.command) + if handler is None: + parser.print_help() + return 1 + + return handler(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/data_juicer/utils/registry.py b/data_juicer/utils/registry.py index e438d015f9..c8612cea46 100644 --- a/data_juicer/utils/registry.py +++ b/data_juicer/utils/registry.py @@ -82,6 +82,17 @@ def _register_module(self, module_name=None, module_cls=None, force=False): self._modules[module_name] = module_cls module_cls._name = module_name + def unregister_module(self, module_name: str) -> bool: + """Remove a module from the registry by name. + + :param module_name: name of the module to remove + :return: True if removed, False if not found. + """ + if module_name in self._modules: + del self._modules[module_name] + return True + return False + def register_module(self, module_name: str = None, module_cls: type = None, force=False): """ Register module class object to registry with the specified modulename. diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index 669af22f21..d1243b3bfd 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -2,6 +2,8 @@ - [How-to Guide for Developers](#how-to-guide-for-developers) - [1. Build Your Own OPs Quickly](#1-build-your-own-ops-quickly) + - [Persistent Custom Operator Registration](#persistent-custom-operator-registration) + - [Build an OP Step by Step](#build-an-op-step-by-step) - [2. Build Your Own Data Recipes and Configs](#2-build-your-own-data-recipes-and-configs) - [2.1 Fruitful Config Sources \& Type Hints](#21-fruitful-config-sources--type-hints) - [2.2 Hierarchical Configs and Helps](#22-hierarchical-configs-and-helps) @@ -27,6 +29,43 @@ > The development process of the following example takes directly adding operators in the corresponding module of the source code as an example. If an operator is added externally, the new operator can be registered by passing the parameter `--custom-operator-paths` or configuring the `custom_operator_paths` parameter in the yaml file, for example: `custom_operator_paths: ['/path/to/new/op.py', '/path/to/new/ops/directory/]`. +### Persistent Custom Operator Registration + +In addition to the per-run `--custom-operator-paths` approach above, Data-Juicer provides a **persistent custom operator registry** so that externally developed operators survive across processes and sessions without repeating configuration. + +The registry is stored at `~/.data_juicer/op_registry.json` (override with the `DJ_OP_REGISTRY` environment variable). Manage it via the CLI: + +```bash +# Register custom operator(s) — accepts file or directory paths +python -m data_juicer.utils.custom_op register /path/to/my_mapper.py + +# List all registered custom operators +python -m data_juicer.utils.custom_op list + +# Unregister by operator name +python -m data_juicer.utils.custom_op unregister my_mapper + +# Clear all custom operator registrations +python -m data_juicer.utils.custom_op reset +``` + +Once registered, custom operators are **automatically loaded** on every Data-Juicer startup. Stale entries (whose source files no longer exist) are cleaned up automatically. + +You can also search and inspect both built-in and custom operators with the operator search tool: + +```bash +# List all operators (built-in + custom) +python -m data_juicer.tools.op_search list + +# Show detailed info for a specific operator +python -m data_juicer.tools.op_search info my_mapper + +# Search operators by keyword +python -m data_juicer.tools.op_search search "text length" +``` + +### Build an OP Step by Step + Assuming we want to add a new Filter operator called "TextLengthFilter" to get corpus of expected text length, we can follow the following steps to build it. 1. (Optional) If the new OP defines some statistical variables, please add the corresponding new `StatsKeysConstant` attribute in `data_juicer/utils/constant.py` for unified management. @@ -120,7 +159,13 @@ process: max_len: 1000 ``` -6. Community contributors can submit corresponding operator PRs and work with the Data-Juicer team to gradually improve it in subsequent PRs. Please see more details [below](#4-contribution-to-the-open-source-community). We greatly welcome co-construction and will [highlight acknowledgements](https://github.com/datajuicer/data-juicer?tab=readme-ov-file#contribution-and-acknowledgements)! +6. (Optional) If you develop custom operators outside the Data-Juicer source tree, you can **persistently register** them so they are available across all future sessions without adding `custom_operator_paths` every time: + + ```bash + python -m data_juicer.utils.custom_op register /path/to/text_length_filter.py + ``` + +7. Community contributors can submit corresponding operator PRs and work with the Data-Juicer team to gradually improve it in subsequent PRs. Please see more details [below](#4-contribution-to-the-open-source-community). We greatly welcome co-construction and will [highlight acknowledgements](https://github.com/datajuicer/data-juicer?tab=readme-ov-file#contribution-and-acknowledgements)! ## 2. Build Your Own Data Recipes and Configs diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index 4bfa947685..6623738148 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -2,6 +2,8 @@ - [开发者指南](#开发者指南) - [1. 快速构建你自己的算子](#1-快速构建你自己的算子) + - [持久化自定义算子注册](#持久化自定义算子注册) + - [逐步构建一个算子](#逐步构建一个算子) - [2. 构建你自己的数据菜谱和配置项](#2-构建你自己的数据菜谱和配置项) - [2.1 丰富的配置源和类型提示](#21-丰富的配置源和类型提示) - [2.2 层次化的配置和帮助](#22-层次化的配置和帮助) @@ -27,6 +29,43 @@ > 以下示例的开发过程以直接在源码对应模块中添加算子为例。如果外部添加算子,可以通过传参`--custom-operator-paths` 或 yaml配置文件中配置`custom_operator_paths`参数注册新算子,例如:`custom_operator_paths: ['/path/to/new/op.py', '/path/to/new/ops/directory/]`。 +### 持久化自定义算子注册 + +除了上述每次运行时指定 `--custom-operator-paths` 的方式外,Data-Juicer 还提供了**持久化自定义算子注册表**,使外部开发的算子能够跨进程、跨会话持续生效,无需重复配置。 + +注册表存储在 `~/.data_juicer/op_registry.json`(可通过环境变量 `DJ_OP_REGISTRY` 覆盖路径)。通过 CLI 管理: + +```bash +# 注册自定义算子 — 支持文件或目录路径 +python -m data_juicer.utils.custom_op register /path/to/my_mapper.py + +# 列出所有已注册的自定义算子 +python -m data_juicer.utils.custom_op list + +# 按算子名称取消注册 +python -m data_juicer.utils.custom_op unregister my_mapper + +# 清除所有自定义算子注册 +python -m data_juicer.utils.custom_op reset +``` + +注册完成后,自定义算子会在每次 Data-Juicer 启动时**自动加载**。源文件已不存在的失效条目会被自动清理。 + +你还可以使用算子搜索工具查询和检视内置算子与自定义算子: + +```bash +# 列出所有算子(内置 + 自定义) +python -m data_juicer.tools.op_search list + +# 查看某个算子的详细信息 +python -m data_juicer.tools.op_search info my_mapper + +# 按关键词搜索算子 +python -m data_juicer.tools.op_search search "text length" +``` + +### 逐步构建一个算子 + 下面以 "TextLengthFilter" 的算子(过滤仅包含预期文本长度的样本语料)为例,展示相应开发构建过程。 1. (可选) 如果该算子定义了某个统计变量,那么请在 `data_juicer/utils/constant.py` 文件中添加一个新的`StatsKeys`属性来统一保存管理。 @@ -121,7 +160,13 @@ process: max_len: 1000 ``` -6. 社区贡献者可在alpha状态后就提相应算子PR。此后该贡献者可以与Data-Juicer团队一起在后续PR中,将其渐进完善到beta和stable版本。更多细节请参考下方第4节。我们非常欢迎共建,并会[高亮致谢](https://github.com/datajuicer/data-juicer?tab=readme-ov-file#contribution-and-acknowledgements)! +6. (可选)如果你在 Data-Juicer 源码树之外开发自定义算子,可以**持久化注册**它们,这样在后续所有会话中都可以直接使用,无需每次都添加 `custom_operator_paths`: + + ```bash + python -m data_juicer.utils.custom_op register /path/to/text_length_filter.py + ``` + +7. 社区贡献者可在alpha状态后就提相应算子PR。此后该贡献者可以与Data-Juicer团队一起在后续PR中,将其渐进完善到beta和stable版本。更多细节请参考下方第4节。我们非常欢迎共建,并会[高亮致谢](https://github.com/datajuicer/data-juicer?tab=readme-ov-file#contribution-and-acknowledgements)! ## 2. 构建你自己的数据菜谱和配置项 diff --git a/tests/tools/test_op_search.py b/tests/tools/test_op_search.py index 856c281e96..c684f83962 100644 --- a/tests/tools/test_op_search.py +++ b/tests/tools/test_op_search.py @@ -1,7 +1,10 @@ +import os +import tempfile import unittest from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase from data_juicer.tools.op_search import OPSearcher +from data_juicer.tools.op_search import main as op_search_main class OPRecordTest(DataJuicerTestCaseBase): @@ -250,5 +253,82 @@ def test_records_map_deprecated_still_returns_all_ops(self): self.assertEqual(records_map, searcher.all_ops) +# --------------------------------------------------------------------------- +# CLI tests +# --------------------------------------------------------------------------- + + +class OpSearchCLIListTest(DataJuicerTestCaseBase): + """Test the 'list' CLI sub-command.""" + + def setUp(self): + self._tmp_dir = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp_dir.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp_dir.cleanup() + + def test_list(self): + rc = op_search_main(["list"]) + self.assertEqual(rc, 0) + +class OpSearchCLIInfoTest(DataJuicerTestCaseBase): + """Test the 'info' CLI sub-command.""" + + def setUp(self): + self._tmp_dir = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp_dir.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp_dir.cleanup() + + def test_info_builtin(self): + """Info on a built-in op should succeed.""" + rc = op_search_main(["info", "text_length_filter"]) + self.assertEqual(rc, 0) + + def test_info_not_found(self): + rc = op_search_main(["info", "no_such_op_xyz"]) + self.assertEqual(rc, 1) + +class OpSearchCLISearchTest(DataJuicerTestCaseBase): + """Test the 'search' CLI sub-command.""" + + def setUp(self): + self._tmp_dir = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp_dir.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp_dir.cleanup() + + def test_search_bm25(self): + rc = op_search_main(["search", "text length"]) + self.assertEqual(rc, 0) + + def test_search_regex(self): + rc = op_search_main(["search", "text.*filter", "--mode", "regex"]) + self.assertEqual(rc, 0) + + def test_search_by_tags(self): + rc = op_search_main(["search", "--tags", "cpu"]) + self.assertEqual(rc, 0) + + def test_search_by_type(self): + rc = op_search_main(["search", "--type", "mapper"]) + self.assertEqual(rc, 0) + +class OpSearchCLINoCommandTest(DataJuicerTestCaseBase): + """Test calling op_search CLI with no command.""" + + def test_no_command(self): + rc = op_search_main([]) + self.assertEqual(rc, 1) + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/test_custom_op.py b/tests/utils/test_custom_op.py new file mode 100644 index 0000000000..a9e54bf24d --- /dev/null +++ b/tests/utils/test_custom_op.py @@ -0,0 +1,429 @@ +""" +Tests for data_juicer.utils.custom_op + +These tests exercise the persistent custom-op registry end-to-end: +register, unregister, reset, list, load, stale-entry cleanup, +and the CLI sub-commands. + +Each test uses a private temp directory for the registry JSON and a +throwaway custom-op source file so that nothing leaks between tests. +""" + +import os +import subprocess +import sys +import tempfile +import textwrap +import unittest + +from data_juicer.ops.base_op import OPERATORS +from data_juicer.utils.custom_op import ( + _read_registry, + _write_registry, + get_registry_path, + list_registered, + load_persistent_custom_ops, + main as custom_op_main, + register_persistent, + reset_registry, + unregister_ops, +) +from data_juicer.utils.registry import Registry +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + +def _make_custom_op_file(tmp_dir: str, op_name: str = "my_test_mapper") -> str: + """Create a minimal custom mapper .py file and return its path.""" + code = textwrap.dedent(f"""\ + from data_juicer.ops.base_op import OPERATORS, Mapper + + @OPERATORS.register_module('{op_name}') + class MyTestMapper(Mapper): + \"\"\"A test custom mapper.\"\"\" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def process_single(self, sample): + return sample + """) + path = os.path.join(tmp_dir, f"{op_name}.py") + with open(path, "w", encoding="utf-8") as f: + f.write(code) + return path + +# --------------------------------------------------------------------------- +# Unit tests for low-level API +# --------------------------------------------------------------------------- + +class RegistryUnregisterTest(DataJuicerTestCaseBase): + """Test Registry.unregister_module (new method).""" + + def test_unregister_existing(self): + reg = Registry("test_unreg") + + class Dummy: + pass + + reg.register_module("dummy_op", Dummy) + self.assertIn("dummy_op", reg.modules) + result = reg.unregister_module("dummy_op") + self.assertTrue(result) + self.assertNotIn("dummy_op", reg.modules) + + def test_unregister_nonexistent(self): + reg = Registry("test_unreg2") + result = reg.unregister_module("no_such_op") + self.assertFalse(result) + +class RegistryPathTest(DataJuicerTestCaseBase): + """Test get_registry_path with and without env override.""" + + def test_default_path(self): + env_backup = os.environ.pop("DJ_OP_REGISTRY", None) + try: + path = get_registry_path() + self.assertTrue(str(path).endswith("op_registry.json")) + self.assertIn(".data_juicer", str(path)) + finally: + if env_backup is not None: + os.environ["DJ_OP_REGISTRY"] = env_backup + + def test_env_override(self): + with tempfile.TemporaryDirectory() as tmp: + custom_path = os.path.join(tmp, "custom_reg.json") + os.environ["DJ_OP_REGISTRY"] = custom_path + try: + self.assertEqual(str(get_registry_path()), custom_path) + finally: + del os.environ["DJ_OP_REGISTRY"] + +class ReadWriteRegistryTest(DataJuicerTestCaseBase): + """Test _read_registry / _write_registry helpers.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_read_empty(self): + data = _read_registry() + self.assertEqual(data["version"], 1) + self.assertEqual(data["custom_operators"], {}) + + def test_write_and_read(self): + payload = { + "version": 1, + "custom_operators": { + "foo": {"source_path": "/tmp/foo.py", "registered_at": "2026-01-01T00:00:00"} + }, + } + _write_registry(payload) + data = _read_registry() + self.assertEqual(data, payload) + + def test_read_malformed(self): + with open(self._reg_path, "w") as f: + f.write("not json") + data = _read_registry() + self.assertEqual(data["custom_operators"], {}) + +class RegisterPersistentTest(DataJuicerTestCaseBase): + """Test register_persistent end-to-end.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._op_name = "test_reg_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_register_and_json(self): + result = register_persistent([self._op_file]) + self.assertIn(self._op_name, result["registered"]) + self.assertIn(self._op_name, OPERATORS.modules) + + data = _read_registry() + self.assertIn(self._op_name, data["custom_operators"]) + self.assertEqual( + data["custom_operators"][self._op_name]["source_path"], + os.path.abspath(self._op_file), + ) + + def test_register_nonexistent_path(self): + result = register_persistent(["/no/such/path.py"]) + self.assertEqual(result["registered"], []) + self.assertGreater(len(result["warnings"]), 0) + +class UnregisterOpsTest(DataJuicerTestCaseBase): + """Test unregister_ops.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._op_name = "test_unreg_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + register_persistent([self._op_file]) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_unregister_existing(self): + result = unregister_ops([self._op_name]) + self.assertIn(self._op_name, result["removed"]) + self.assertEqual(result["not_found"], []) + self.assertNotIn(self._op_name, OPERATORS.modules) + + data = _read_registry() + self.assertNotIn(self._op_name, data["custom_operators"]) + + def test_unregister_nonexistent(self): + result = unregister_ops(["no_such_op"]) + self.assertEqual(result["removed"], []) + self.assertIn("no_such_op", result["not_found"]) + +class ResetRegistryTest(DataJuicerTestCaseBase): + """Test reset_registry.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._op_name = "test_reset_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + register_persistent([self._op_file]) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_reset_clears_all(self): + result = reset_registry() + self.assertIn(self._op_name, result["removed"]) + self.assertNotIn(self._op_name, OPERATORS.modules) + + data = _read_registry() + self.assertEqual(data["custom_operators"], {}) + +class ListRegisteredTest(DataJuicerTestCaseBase): + """Test list_registered.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_list_empty(self): + result = list_registered() + self.assertEqual(result["custom_operators"], {}) + + def test_list_after_register(self): + op_name = "test_list_mapper" + op_file = _make_custom_op_file(self._tmp.name, op_name) + register_persistent([op_file]) + try: + result = list_registered() + self.assertIn(op_name, result["custom_operators"]) + finally: + OPERATORS.unregister_module(op_name) + sys.modules.pop(op_name, None) + +class LoadPersistentCustomOpsTest(DataJuicerTestCaseBase): + """Test load_persistent_custom_ops including stale-entry cleanup.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_load_valid(self): + op_name = "test_load_mapper" + op_file = _make_custom_op_file(self._tmp.name, op_name) + register_persistent([op_file]) + + OPERATORS.unregister_module(op_name) + sys.modules.pop(op_name, None) + self.assertNotIn(op_name, OPERATORS.modules) + + result = load_persistent_custom_ops() + self.assertIn(op_name, result["loaded"]) + self.assertIn(op_name, OPERATORS.modules) + + OPERATORS.unregister_module(op_name) + sys.modules.pop(op_name, None) + + def test_load_cleans_stale(self): + payload = { + "version": 1, + "custom_operators": { + "stale_op": { + "source_path": "/no/such/file.py", + "registered_at": "2026-01-01T00:00:00", + } + }, + } + _write_registry(payload) + + result = load_persistent_custom_ops() + self.assertIn("stale_op", result["cleaned"]) + self.assertGreater(len(result["warnings"]), 0) + + data = _read_registry() + self.assertNotIn("stale_op", data["custom_operators"]) + +class CrossProcessVisibilityTest(DataJuicerTestCaseBase): + """Test that a custom op registered in one process is visible in another.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._op_name = "test_xproc_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + register_persistent([self._op_file]) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_cross_process(self): + """Spawn a subprocess that loads the registry and checks the op.""" + script = textwrap.dedent(f"""\ + import os, sys, json + os.environ["DJ_OP_REGISTRY"] = "{self._reg_path}" + from data_juicer.utils.custom_op import load_persistent_custom_ops + from data_juicer.ops.base_op import OPERATORS + result = load_persistent_custom_ops() + if "{self._op_name}" in OPERATORS.modules: + print("OK") + else: + print("FAIL") + sys.exit(1) + """) + proc = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=60, + ) + self.assertEqual(proc.returncode, 0, f"stderr: {proc.stderr}") + self.assertIn("OK", proc.stdout) + +# --------------------------------------------------------------------------- +# CLI tests +# --------------------------------------------------------------------------- + +class CustomOpCLIListTest(DataJuicerTestCaseBase): + """Test the 'list' CLI sub-command.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_list_empty(self): + rc = custom_op_main(["list"]) + self.assertEqual(rc, 0) + +class CustomOpCLIRegisterTest(DataJuicerTestCaseBase): + """Test the 'register' CLI sub-command.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._op_name = "cli_reg_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_register(self): + rc = custom_op_main(["register", self._op_file]) + self.assertEqual(rc, 0) + self.assertIn(self._op_name, OPERATORS.modules) + +class CustomOpCLIUnregisterTest(DataJuicerTestCaseBase): + """Test the 'unregister' CLI sub-command.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._op_name = "cli_unreg_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + register_persistent([self._op_file]) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_unregister(self): + rc = custom_op_main(["unregister", self._op_name]) + self.assertEqual(rc, 0) + self.assertNotIn(self._op_name, OPERATORS.modules) + +class CustomOpCLIResetTest(DataJuicerTestCaseBase): + """Test the 'reset' CLI sub-command.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "op_registry.json") + os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._op_name = "cli_reset_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + register_persistent([self._op_file]) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_reset(self): + rc = custom_op_main(["reset"]) + self.assertEqual(rc, 0) + self.assertNotIn(self._op_name, OPERATORS.modules) + +class CustomOpCLINoCommandTest(DataJuicerTestCaseBase): + """Test calling custom_op CLI with no command.""" + + def test_no_command(self): + rc = custom_op_main([]) + self.assertEqual(rc, 1) + +if __name__ == "__main__": + unittest.main() From 0e6a2166011828005a80b0a699d0c8298b7984e0 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Wed, 15 Apr 2026 11:34:15 +0800 Subject: [PATCH 2/8] refactor: simplify custom operator loading logic --- data_juicer/config/config.py | 54 +++++------------------------ data_juicer/utils/custom_op.py | 62 ++++++++++++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 49 deletions(-) diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index d39b11e947..e4dbabe3ac 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -1,6 +1,5 @@ import argparse import copy -import importlib.util import json import os import shutil @@ -51,51 +50,14 @@ def _generate_module_name(abs_path): def load_custom_operators(paths): - """Dynamically load custom operator modules or packages in the specified path.""" - for path in paths: - abs_path = os.path.abspath(path) - if os.path.isfile(abs_path): - module_name = _generate_module_name(abs_path) - if module_name in sys.modules: - existing_path = sys.modules[module_name].__file__ - raise RuntimeError( - f"Module '{module_name}' already loaded from '{existing_path}'. " - f"Conflict detected while loading '{abs_path}'." - ) - try: - spec = importlib.util.spec_from_file_location(module_name, abs_path) - if spec is None: - raise RuntimeError(f"Failed to create spec for '{abs_path}'") - module = importlib.util.module_from_spec(spec) - # register the module first to avoid recursive import issues - sys.modules[module_name] = module - spec.loader.exec_module(module) - except Exception as e: - raise RuntimeError(f"Error loading '{abs_path}' as '{module_name}': {e}") - - elif os.path.isdir(abs_path): - if not os.path.isfile(os.path.join(abs_path, "__init__.py")): - raise ValueError(f"Package directory '{abs_path}' must contain __init__.py") - package_name = os.path.basename(abs_path) - parent_dir = os.path.dirname(abs_path) - if package_name in sys.modules: - existing_path = sys.modules[package_name].__path__[0] - raise RuntimeError( - f"Package '{package_name}' already loaded from '{existing_path}'. " - f"Conflict detected while loading '{abs_path}'." - ) - original_sys_path = sys.path.copy() - try: - sys.path.insert(0, parent_dir) - importlib.import_module(package_name) - # record the loading path of the package (for subsequent conflict detection) - sys.modules[package_name].__loaded_from__ = abs_path - except Exception as e: - raise RuntimeError(f"Error loading package '{abs_path}': {e}") - finally: - sys.path = original_sys_path - else: - raise ValueError(f"Path '{abs_path}' is neither a file nor a directory") + """Dynamically load custom operator modules or packages in the specified path. + + This is a re-export from ``data_juicer.utils.custom_op`` kept here for + backward compatibility. + """ + from data_juicer.utils.custom_op import load_custom_operators as _impl + + _impl(paths) def build_base_parser() -> ArgumentParser: diff --git a/data_juicer/utils/custom_op.py b/data_juicer/utils/custom_op.py index 3d834393db..fec7a46db9 100644 --- a/data_juicer/utils/custom_op.py +++ b/data_juicer/utils/custom_op.py @@ -19,6 +19,7 @@ """ import argparse +import importlib.util import inspect import json import os @@ -30,6 +31,64 @@ from loguru import logger +# --------------------------------------------------------------------------- +# Dynamic module loading +# --------------------------------------------------------------------------- + + +def _generate_module_name(abs_path): + """Generate a module name based on the absolute path of the file.""" + return os.path.splitext(os.path.basename(abs_path))[0] + + +def load_custom_operators(paths): + """Dynamically load custom operator modules or packages in the specified path.""" + for path in paths: + abs_path = os.path.abspath(path) + if os.path.isfile(abs_path): + module_name = _generate_module_name(abs_path) + if module_name in sys.modules: + existing_path = sys.modules[module_name].__file__ + raise RuntimeError( + f"Module '{module_name}' already loaded from '{existing_path}'. " + f"Conflict detected while loading '{abs_path}'." + ) + try: + spec = importlib.util.spec_from_file_location(module_name, abs_path) + if spec is None: + raise RuntimeError(f"Failed to create spec for '{abs_path}'") + module = importlib.util.module_from_spec(spec) + # register the module first to avoid recursive import issues + sys.modules[module_name] = module + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f"Error loading '{abs_path}' as '{module_name}': {e}") + + elif os.path.isdir(abs_path): + if not os.path.isfile(os.path.join(abs_path, "__init__.py")): + raise ValueError(f"Package directory '{abs_path}' must contain __init__.py") + package_name = os.path.basename(abs_path) + parent_dir = os.path.dirname(abs_path) + if package_name in sys.modules: + existing_path = sys.modules[package_name].__path__[0] + raise RuntimeError( + f"Package '{package_name}' already loaded from '{existing_path}'. " + f"Conflict detected while loading '{abs_path}'." + ) + original_sys_path = sys.path.copy() + try: + sys.path.insert(0, parent_dir) + importlib.import_module(package_name) + # record the loading path of the package (for subsequent conflict detection) + sys.modules[package_name].__loaded_from__ = abs_path + except Exception as e: + raise RuntimeError(f"Error loading package '{abs_path}': {e}") + finally: + sys.path = original_sys_path + else: + raise ValueError(f"Path '{abs_path}' is neither a file nor a directory") + + # --------------------------------------------------------------------------- # Path management # --------------------------------------------------------------------------- @@ -110,7 +169,6 @@ def register_persistent(paths: List[str]) -> dict: Returns ``{"registered": [...], "warnings": [...]}``. """ - from data_juicer.config.config import load_custom_operators from data_juicer.ops.base_op import OPERATORS # Snapshot operator names before loading. @@ -270,8 +328,6 @@ def load_persistent_custom_ops() -> dict: Returns ``{"loaded": [...], "cleaned": [...], "warnings": [...]}``. """ - from data_juicer.config.config import load_custom_operators - registry = _read_registry() custom_ops = registry.get("custom_operators", {}) From de1e57419e76a81d6363ee85dae2324f2c9e82b7 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Tue, 14 Apr 2026 12:17:10 +0800 Subject: [PATCH 3/8] refactor: enhance signature annotation resolution for ops search functionality --- data_juicer/tools/DJ_mcp_granular_ops.py | 46 +++++++++++++++++++++--- data_juicer/tools/op_search.py | 2 ++ 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/data_juicer/tools/DJ_mcp_granular_ops.py b/data_juicer/tools/DJ_mcp_granular_ops.py index 1ebd897e40..61d2bd0fec 100644 --- a/data_juicer/tools/DJ_mcp_granular_ops.py +++ b/data_juicer/tools/DJ_mcp_granular_ops.py @@ -2,7 +2,7 @@ import inspect import os import sys -from typing import Annotated, Optional +from typing import Annotated, Optional, get_type_hints from pydantic import Field @@ -13,6 +13,30 @@ fastmcp = LazyLoader("mcp.server.fastmcp", "mcp[cli]") +def resolve_signature_annotations(func, sig: inspect.Signature) -> inspect.Signature: + """Resolve postponed/string annotations into real runtime types. + + When a module uses ``from __future__ import annotations``, all + annotations are stored as strings. This helper calls + ``typing.get_type_hints`` on the original callable to obtain the + real type objects and rebuilds the signature with them. + """ + try: + module = sys.modules.get(func.__module__, None) if hasattr(func, "__module__") else None + globalns = module.__dict__ if module else {} + hints = get_type_hints(func, globalns=globalns, localns=globalns) + except Exception: + hints = {} + + new_params = [] + for name, param in sig.parameters.items(): + resolved_annotation = hints.get(name, param.annotation) + new_params.append(param.replace(annotation=resolved_annotation)) + + return_annotation = hints.get("return", sig.return_annotation) + return sig.replace(parameters=new_params, return_annotation=return_annotation) + + # Dynamic MCP Tool Creation def process_parameter(name: str, param: inspect.Parameter) -> inspect.Parameter: """ @@ -31,13 +55,18 @@ def create_operator_function(op, mcp): This function dynamically creates a function that can be registered as an MCP tool, with proper signature and documentation based on the operator's __init__ method. """ - sig = op["sig"] + raw_sig = op["sig"] + init_func = op.get("init_func") + if init_func is not None: + sig = resolve_signature_annotations(init_func, raw_sig) + else: + sig = raw_sig docstring = op["desc"] param_docstring = op["param_desc"] # Create new function signature with dataset_path as first parameter # Consider adding other common parameters later, such as export_psth - new_parameters = [ + fixed_params = [ inspect.Parameter("dataset_path", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str), inspect.Parameter( "export_path", @@ -51,11 +80,18 @@ def create_operator_function(op, mcp): annotation=Optional[int], default=None, ), - ] + [ + ] + op_params = [ process_parameter(name, param) for name, param in sig.parameters.items() if name not in ("args", "kwargs", "self") ] + # Merge all params, then reorder: required (no default) first, + # optional (with default) second, to satisfy Python's signature rule. + all_params = fixed_params + op_params + required_params = [p for p in all_params if p.default is inspect.Parameter.empty] + optional_params = [p for p in all_params if p.default is not inspect.Parameter.empty] + new_parameters = required_params + optional_params new_signature = sig.replace(parameters=new_parameters, return_annotation=str) def func(*args, **kwargs): @@ -66,7 +102,7 @@ def func(*args, **kwargs): export_path = bound_arguments.arguments.pop("export_path") dataset_path = bound_arguments.arguments.pop("dataset_path") np = bound_arguments.arguments.pop("np") - args_dict = {k: v for k, v in bound_arguments.arguments.items() if v} + args_dict = {k: v for k, v in bound_arguments.arguments.items() if v is not None} dj_cfg = { "dataset_path": dataset_path, diff --git a/data_juicer/tools/op_search.py b/data_juicer/tools/op_search.py index 32cda3f0d9..416b6e12e1 100644 --- a/data_juicer/tools/op_search.py +++ b/data_juicer/tools/op_search.py @@ -176,6 +176,7 @@ def __init__(self, name: str, op_cls: type, op_type: Optional[str] = None): self.desc = op_cls.__doc__ or "" self.tags = analyze_tag_from_cls(op_cls, name) self.sig = inspect.signature(op_cls.__init__) + self.init_func = op_cls.__init__ self.param_desc = extract_param_docstring(op_cls.__init__.__doc__ or "") self.param_desc_map = self._parse_param_desc() @@ -229,6 +230,7 @@ def to_dict(self): "desc": self.desc, "tags": self.tags, "sig": self.sig, + "init_func": self.init_func, "param_desc": self.param_desc, "param_desc_map": self.param_desc_map, "source_path": self.source_path, From 63d53f372c43118682965e33d046cb2d2a74e6d0 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Wed, 15 Apr 2026 11:51:38 +0800 Subject: [PATCH 4/8] refactor: rename custom op registry file path --- data_juicer/ops/__init__.py | 2 +- data_juicer/utils/custom_op.py | 10 ++--- tests/tools/test_op_search.py | 18 ++++---- tests/utils/test_custom_op.py | 78 +++++++++++++++++----------------- 4 files changed, 54 insertions(+), 54 deletions(-) diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py index b151807c0f..cbfc23ad94 100644 --- a/data_juicer/ops/__init__.py +++ b/data_juicer/ops/__init__.py @@ -16,7 +16,7 @@ def timing_context(description): with timing_context('Importing operator modules'): # 1. Built-in operators (registered via @OPERATORS.register_module decorators # that fire as each sub-package is imported) - # 2. Persistent custom operators (loaded from ~/.data_juicer/op_registry.json; + # 2. Persistent custom operators (loaded from ~/.data_juicer/custom_op.json; # no-op when the registry file does not exist) from data_juicer.utils.custom_op import ( load_persistent_custom_ops as _load_persistent, diff --git a/data_juicer/utils/custom_op.py b/data_juicer/utils/custom_op.py index fec7a46db9..5b2453a3ce 100644 --- a/data_juicer/utils/custom_op.py +++ b/data_juicer/utils/custom_op.py @@ -2,7 +2,7 @@ Custom Operator Management =========================== -Manages a persistent JSON registry at ``~/.data_juicer/op_registry.json`` +Manages a persistent JSON registry at ``~/.data_juicer/custom_op.json`` so that custom operators survive across processes. The registry file stores source-file paths keyed by operator name. @@ -97,13 +97,13 @@ def load_custom_operators(paths): def get_registry_path() -> Path: """Return the path to the persistent op registry JSON file. - Defaults to ``~/.data_juicer/op_registry.json``. - Override with the ``DJ_OP_REGISTRY`` environment variable. + Defaults to ``~/.data_juicer/custom_op.json``. + Override with the ``DJ_CUSTOM_OP_REGISTRY`` environment variable. """ - override = os.environ.get("DJ_OP_REGISTRY") + override = os.environ.get("DJ_CUSTOM_OP_REGISTRY") if override: return Path(override) - return Path.home() / ".data_juicer" / "op_registry.json" + return Path.home() / ".data_juicer" / "custom_op.json" # --------------------------------------------------------------------------- diff --git a/tests/tools/test_op_search.py b/tests/tools/test_op_search.py index c684f83962..4de833f31e 100644 --- a/tests/tools/test_op_search.py +++ b/tests/tools/test_op_search.py @@ -263,11 +263,11 @@ class OpSearchCLIListTest(DataJuicerTestCaseBase): def setUp(self): self._tmp_dir = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp_dir.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp_dir.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path def tearDown(self): - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp_dir.cleanup() def test_list(self): @@ -279,11 +279,11 @@ class OpSearchCLIInfoTest(DataJuicerTestCaseBase): def setUp(self): self._tmp_dir = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp_dir.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp_dir.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path def tearDown(self): - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp_dir.cleanup() def test_info_builtin(self): @@ -300,11 +300,11 @@ class OpSearchCLISearchTest(DataJuicerTestCaseBase): def setUp(self): self._tmp_dir = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp_dir.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp_dir.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path def tearDown(self): - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp_dir.cleanup() def test_search_bm25(self): diff --git a/tests/utils/test_custom_op.py b/tests/utils/test_custom_op.py index a9e54bf24d..244369a5ad 100644 --- a/tests/utils/test_custom_op.py +++ b/tests/utils/test_custom_op.py @@ -78,34 +78,34 @@ class RegistryPathTest(DataJuicerTestCaseBase): """Test get_registry_path with and without env override.""" def test_default_path(self): - env_backup = os.environ.pop("DJ_OP_REGISTRY", None) + env_backup = os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) try: path = get_registry_path() - self.assertTrue(str(path).endswith("op_registry.json")) + self.assertTrue(str(path).endswith("custom_op.json")) self.assertIn(".data_juicer", str(path)) finally: if env_backup is not None: - os.environ["DJ_OP_REGISTRY"] = env_backup + os.environ["DJ_CUSTOM_OP_REGISTRY"] = env_backup def test_env_override(self): with tempfile.TemporaryDirectory() as tmp: custom_path = os.path.join(tmp, "custom_reg.json") - os.environ["DJ_OP_REGISTRY"] = custom_path + os.environ["DJ_CUSTOM_OP_REGISTRY"] = custom_path try: self.assertEqual(str(get_registry_path()), custom_path) finally: - del os.environ["DJ_OP_REGISTRY"] + del os.environ["DJ_CUSTOM_OP_REGISTRY"] class ReadWriteRegistryTest(DataJuicerTestCaseBase): """Test _read_registry / _write_registry helpers.""" def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path def tearDown(self): - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_read_empty(self): @@ -135,15 +135,15 @@ class RegisterPersistentTest(DataJuicerTestCaseBase): def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path self._op_name = "test_reg_mapper" self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) def tearDown(self): OPERATORS.unregister_module(self._op_name) sys.modules.pop(self._op_name, None) - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_register_and_json(self): @@ -168,8 +168,8 @@ class UnregisterOpsTest(DataJuicerTestCaseBase): def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path self._op_name = "test_unreg_mapper" self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) register_persistent([self._op_file]) @@ -177,7 +177,7 @@ def setUp(self): def tearDown(self): OPERATORS.unregister_module(self._op_name) sys.modules.pop(self._op_name, None) - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_unregister_existing(self): @@ -199,8 +199,8 @@ class ResetRegistryTest(DataJuicerTestCaseBase): def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path self._op_name = "test_reset_mapper" self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) register_persistent([self._op_file]) @@ -208,7 +208,7 @@ def setUp(self): def tearDown(self): OPERATORS.unregister_module(self._op_name) sys.modules.pop(self._op_name, None) - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_reset_clears_all(self): @@ -224,11 +224,11 @@ class ListRegisteredTest(DataJuicerTestCaseBase): def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path def tearDown(self): - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_list_empty(self): @@ -251,11 +251,11 @@ class LoadPersistentCustomOpsTest(DataJuicerTestCaseBase): def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path def tearDown(self): - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_load_valid(self): @@ -298,8 +298,8 @@ class CrossProcessVisibilityTest(DataJuicerTestCaseBase): def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path self._op_name = "test_xproc_mapper" self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) register_persistent([self._op_file]) @@ -307,14 +307,14 @@ def setUp(self): def tearDown(self): OPERATORS.unregister_module(self._op_name) sys.modules.pop(self._op_name, None) - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_cross_process(self): """Spawn a subprocess that loads the registry and checks the op.""" script = textwrap.dedent(f"""\ import os, sys, json - os.environ["DJ_OP_REGISTRY"] = "{self._reg_path}" + os.environ["DJ_CUSTOM_OP_REGISTRY"] = "{self._reg_path}" from data_juicer.utils.custom_op import load_persistent_custom_ops from data_juicer.ops.base_op import OPERATORS result = load_persistent_custom_ops() @@ -342,11 +342,11 @@ class CustomOpCLIListTest(DataJuicerTestCaseBase): def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path def tearDown(self): - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_list_empty(self): @@ -358,15 +358,15 @@ class CustomOpCLIRegisterTest(DataJuicerTestCaseBase): def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path self._op_name = "cli_reg_mapper" self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) def tearDown(self): OPERATORS.unregister_module(self._op_name) sys.modules.pop(self._op_name, None) - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_register(self): @@ -379,8 +379,8 @@ class CustomOpCLIUnregisterTest(DataJuicerTestCaseBase): def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path self._op_name = "cli_unreg_mapper" self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) register_persistent([self._op_file]) @@ -388,7 +388,7 @@ def setUp(self): def tearDown(self): OPERATORS.unregister_module(self._op_name) sys.modules.pop(self._op_name, None) - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_unregister(self): @@ -401,8 +401,8 @@ class CustomOpCLIResetTest(DataJuicerTestCaseBase): def setUp(self): self._tmp = tempfile.TemporaryDirectory() - self._reg_path = os.path.join(self._tmp.name, "op_registry.json") - os.environ["DJ_OP_REGISTRY"] = self._reg_path + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path self._op_name = "cli_reset_mapper" self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) register_persistent([self._op_file]) @@ -410,7 +410,7 @@ def setUp(self): def tearDown(self): OPERATORS.unregister_module(self._op_name) sys.modules.pop(self._op_name, None) - os.environ.pop("DJ_OP_REGISTRY", None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) self._tmp.cleanup() def test_reset(self): From ea55f7d158314045631a38068043e4f3c14fed2b Mon Sep 17 00:00:00 2001 From: cmgzn Date: Wed, 15 Apr 2026 11:58:39 +0800 Subject: [PATCH 5/8] docs: update registry path and env var names in Developer Guide --- docs/DeveloperGuide.md | 2 +- docs/DeveloperGuide_ZH.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index d1243b3bfd..793e732124 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -33,7 +33,7 @@ In addition to the per-run `--custom-operator-paths` approach above, Data-Juicer provides a **persistent custom operator registry** so that externally developed operators survive across processes and sessions without repeating configuration. -The registry is stored at `~/.data_juicer/op_registry.json` (override with the `DJ_OP_REGISTRY` environment variable). Manage it via the CLI: +The registry is stored at `~/.data_juicer/custom_op.json` (override with the `DJ_CUSTOM_OP_REGISTRY` environment variable). Manage it via the CLI: ```bash # Register custom operator(s) — accepts file or directory paths diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index 6623738148..8533de1571 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -33,7 +33,7 @@ 除了上述每次运行时指定 `--custom-operator-paths` 的方式外,Data-Juicer 还提供了**持久化自定义算子注册表**,使外部开发的算子能够跨进程、跨会话持续生效,无需重复配置。 -注册表存储在 `~/.data_juicer/op_registry.json`(可通过环境变量 `DJ_OP_REGISTRY` 覆盖路径)。通过 CLI 管理: +注册表存储在 `~/.data_juicer/custom_op.json`(可通过环境变量 `DJ_CUSTOM_OP_REGISTRY` 覆盖路径)。通过 CLI 管理: ```bash # 注册自定义算子 — 支持文件或目录路径 From 18b437d4a15acab3e128f32810986f31896d839a Mon Sep 17 00:00:00 2001 From: cmgzn Date: Wed, 15 Apr 2026 12:07:36 +0800 Subject: [PATCH 6/8] fix: enhance error handling for op search paths --- data_juicer/tools/op_search.py | 7 +++++-- data_juicer/utils/custom_op.py | 31 +++++++++++++++++++------------ 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/data_juicer/tools/op_search.py b/data_juicer/tools/op_search.py index 416b6e12e1..49a224b613 100644 --- a/data_juicer/tools/op_search.py +++ b/data_juicer/tools/op_search.py @@ -183,8 +183,11 @@ def __init__(self, name: str, op_cls: type, op_type: Optional[str] = None): # --- source path: handling for custom ops --- try: self.source_path = str(get_source_path(op_cls)) - except ValueError: - self.source_path = str(Path(inspect.getfile(op_cls))) + except (ValueError, TypeError, OSError): + try: + self.source_path = str(Path(inspect.getfile(op_cls))) + except (TypeError, OSError): + self.source_path = "unknown" # --- test path: handling for custom ops --- try: diff --git a/data_juicer/utils/custom_op.py b/data_juicer/utils/custom_op.py index 5b2453a3ce..f8cae49769 100644 --- a/data_juicer/utils/custom_op.py +++ b/data_juicer/utils/custom_op.py @@ -259,20 +259,24 @@ def unregister_ops(names: List[str]) -> dict: for name in names: if name in registry["custom_operators"]: + meta = registry["custom_operators"][name] del registry["custom_operators"][name] removed.append(name) + + # Also remove from in-process registry. + OPERATORS.unregister_module(name) + + # Remove from sys.modules using the source file stem for + # precise matching — avoids accidentally removing built-in + # modules that happen to share a suffix. + src = meta.get("source_path") + if src: + module_name = os.path.splitext(os.path.basename(src))[0] + if module_name in sys.modules: + del sys.modules[module_name] else: not_found.append(name) - # Also remove from in-process registry. - OPERATORS.unregister_module(name) - - # Remove from sys.modules if present. - # The module name is typically the stem of the source file. - to_remove = [k for k in sys.modules if k == name or k.endswith(f".{name}")] - for k in to_remove: - del sys.modules[k] - _write_registry(registry) return {"removed": removed, "not_found": not_found} @@ -288,11 +292,14 @@ def reset_registry() -> dict: registry = _read_registry() removed = sorted(registry["custom_operators"].keys()) + custom_ops = registry["custom_operators"] for name in removed: OPERATORS.unregister_module(name) - to_remove = [k for k in sys.modules if k == name or k.endswith(f".{name}")] - for k in to_remove: - del sys.modules[k] + meta = custom_ops.get(name) + if meta and "source_path" in meta: + module_name = os.path.splitext(os.path.basename(meta["source_path"]))[0] + if module_name in sys.modules: + del sys.modules[module_name] registry["custom_operators"] = {} _write_registry(registry) From 56d6fb477c5a0df3a02d81efce93b02cba114cfa Mon Sep 17 00:00:00 2001 From: cmgzn Date: Wed, 15 Apr 2026 13:26:48 +0800 Subject: [PATCH 7/8] correct typo and improve error handling for custom operators loading --- data_juicer/tools/DJ_mcp_granular_ops.py | 2 +- data_juicer/utils/custom_op.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/data_juicer/tools/DJ_mcp_granular_ops.py b/data_juicer/tools/DJ_mcp_granular_ops.py index 61d2bd0fec..7fe4532916 100644 --- a/data_juicer/tools/DJ_mcp_granular_ops.py +++ b/data_juicer/tools/DJ_mcp_granular_ops.py @@ -65,7 +65,7 @@ def create_operator_function(op, mcp): param_docstring = op["param_desc"] # Create new function signature with dataset_path as first parameter - # Consider adding other common parameters later, such as export_psth + # Consider adding other common parameters later, such as export_path fixed_params = [ inspect.Parameter("dataset_path", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str), inspect.Parameter( diff --git a/data_juicer/utils/custom_op.py b/data_juicer/utils/custom_op.py index f8cae49769..9a75db2514 100644 --- a/data_juicer/utils/custom_op.py +++ b/data_juicer/utils/custom_op.py @@ -62,6 +62,8 @@ def load_custom_operators(paths): sys.modules[module_name] = module spec.loader.exec_module(module) except Exception as e: + # Clean up partially-initialized module to avoid stale entries + sys.modules.pop(module_name, None) raise RuntimeError(f"Error loading '{abs_path}' as '{module_name}': {e}") elif os.path.isdir(abs_path): From 1e075cd6db78b6596027549ae80b6e17a08c36d6 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Thu, 16 Apr 2026 17:36:41 +0800 Subject: [PATCH 8/8] refactor: update custom operator management to use registration paths instead of operator names --- data_juicer/config/config.py | 5 - data_juicer/ops/__init__.py | 11 +- data_juicer/tools/DJ_mcp_granular_ops.py | 2 +- data_juicer/utils/custom_op.py | 432 ++++++++++++++--------- docs/DeveloperGuide.md | 4 +- docs/DeveloperGuide_ZH.md | 4 +- tests/utils/test_custom_op.py | 92 +++-- 7 files changed, 344 insertions(+), 206 deletions(-) diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index e4dbabe3ac..8e4c5549c2 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -44,11 +44,6 @@ def timing_context(description): logger.debug(f"{description} took {elapsed_time:.2f} seconds") -def _generate_module_name(abs_path): - """Generate a module name based on the absolute path of the file.""" - return os.path.splitext(os.path.basename(abs_path))[0] - - def load_custom_operators(paths): """Dynamically load custom operator modules or packages in the specified path. diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py index cbfc23ad94..373d0e86ae 100644 --- a/data_juicer/ops/__init__.py +++ b/data_juicer/ops/__init__.py @@ -18,14 +18,7 @@ def timing_context(description): # that fire as each sub-package is imported) # 2. Persistent custom operators (loaded from ~/.data_juicer/custom_op.json; # no-op when the registry file does not exist) - from data_juicer.utils.custom_op import ( - load_persistent_custom_ops as _load_persistent, - ) - from . import aggregator, deduplicator, filter, grouper, mapper, pipeline, selector - _load_persistent() - del _load_persistent - from .base_op import ( ATTRIBUTION_FILTERS, NON_STATS_FILTERS, @@ -49,6 +42,10 @@ def timing_context(description): op_requirements_to_op_env_spec, ) + from data_juicer.utils.custom_op import load_persistent_custom_ops as _load_persistent # isort: skip # noqa: E501 + _load_persistent() + del _load_persistent + __all__ = [ "load_ops", "Filter", diff --git a/data_juicer/tools/DJ_mcp_granular_ops.py b/data_juicer/tools/DJ_mcp_granular_ops.py index 7fe4532916..89cde83508 100644 --- a/data_juicer/tools/DJ_mcp_granular_ops.py +++ b/data_juicer/tools/DJ_mcp_granular_ops.py @@ -24,7 +24,7 @@ def resolve_signature_annotations(func, sig: inspect.Signature) -> inspect.Signa try: module = sys.modules.get(func.__module__, None) if hasattr(func, "__module__") else None globalns = module.__dict__ if module else {} - hints = get_type_hints(func, globalns=globalns, localns=globalns) + hints = get_type_hints(func, globalns=globalns) except Exception: hints = {} diff --git a/data_juicer/utils/custom_op.py b/data_juicer/utils/custom_op.py index 9a75db2514..0d1771d354 100644 --- a/data_juicer/utils/custom_op.py +++ b/data_juicer/utils/custom_op.py @@ -5,22 +5,25 @@ Manages a persistent JSON registry at ``~/.data_juicer/custom_op.json`` so that custom operators survive across processes. -The registry file stores source-file paths keyed by operator name. +The registry stores only **registration paths** (file or directory) as +primary keys. It does **not** cache operator names — the actual operator +list is always derived at runtime from the in-process ``OPERATORS`` +registry after loading, so it is always up-to-date with the file system. + On startup, ``load_persistent_custom_ops()`` replays those registrations into the in-process ``OPERATORS`` registry, automatically cleaning up -entries whose source files no longer exist. +entries whose source paths no longer exist. Also provides a CLI for managing custom operators:: python -m data_juicer.utils.custom_op list python -m data_juicer.utils.custom_op register /path/to/my_mapper.py - python -m data_juicer.utils.custom_op unregister my_mapper + python -m data_juicer.utils.custom_op unregister /path/to/my_mapper.py python -m data_juicer.utils.custom_op reset """ import argparse import importlib.util -import inspect import json import os import sys @@ -41,10 +44,25 @@ def _generate_module_name(abs_path): return os.path.splitext(os.path.basename(abs_path))[0] +def _rollback_operators(new_names): + """Remove the given operator names from OPERATORS. + + Used to roll back a partially-successful load: *new_names* is the set + of operators that were registered after the snapshot but before the + error occurred. + """ + try: + from data_juicer.ops.base_op import OPERATORS + except ImportError: + return + for name in new_names: + OPERATORS.unregister_module(name) + + def load_custom_operators(paths): """Dynamically load custom operator modules or packages in the specified path.""" for path in paths: - abs_path = os.path.abspath(path) + abs_path = os.path.realpath(path) if os.path.isfile(abs_path): module_name = _generate_module_name(abs_path) if module_name in sys.modules: @@ -53,9 +71,12 @@ def load_custom_operators(paths): f"Module '{module_name}' already loaded from '{existing_path}'. " f"Conflict detected while loading '{abs_path}'." ) + from data_juicer.ops.base_op import OPERATORS + + ops_before = set(OPERATORS.modules.keys()) try: spec = importlib.util.spec_from_file_location(module_name, abs_path) - if spec is None: + if spec is None or spec.loader is None: raise RuntimeError(f"Failed to create spec for '{abs_path}'") module = importlib.util.module_from_spec(spec) # register the module first to avoid recursive import issues @@ -64,6 +85,8 @@ def load_custom_operators(paths): except Exception as e: # Clean up partially-initialized module to avoid stale entries sys.modules.pop(module_name, None) + ops_after = set(OPERATORS.modules.keys()) + _rollback_operators(ops_after - ops_before) raise RuntimeError(f"Error loading '{abs_path}' as '{module_name}': {e}") elif os.path.isdir(abs_path): @@ -77,13 +100,18 @@ def load_custom_operators(paths): f"Package '{package_name}' already loaded from '{existing_path}'. " f"Conflict detected while loading '{abs_path}'." ) + from data_juicer.ops.base_op import OPERATORS + + ops_before = set(OPERATORS.modules.keys()) original_sys_path = sys.path.copy() try: sys.path.insert(0, parent_dir) importlib.import_module(package_name) - # record the loading path of the package (for subsequent conflict detection) - sys.modules[package_name].__loaded_from__ = abs_path + # record the loading path of the package (custom attribute) + setattr(sys.modules[package_name], "__loaded_from__", abs_path) except Exception as e: + ops_after = set(OPERATORS.modules.keys()) + _rollback_operators(ops_after - ops_before) raise RuntimeError(f"Error loading package '{abs_path}': {e}") finally: sys.path = original_sys_path @@ -111,26 +139,28 @@ def get_registry_path() -> Path: # --------------------------------------------------------------------------- # Low-level read / write helpers # --------------------------------------------------------------------------- - -_EMPTY_REGISTRY: Dict = {"version": 1, "custom_operators": {}} +def _empty_registry() -> dict: + """Return a fresh empty registry structure.""" + return {"version": 2, "registrations": {}} def _read_registry() -> dict: """Read the JSON registry. Returns the empty structure when the file - does not exist or is malformed.""" + does not exist or is malformed. + """ path = get_registry_path() if not path.exists(): - return json.loads(json.dumps(_EMPTY_REGISTRY)) # deep copy + return _empty_registry() try: with open(path, "r", encoding="utf-8") as fh: data = json.load(fh) - if not isinstance(data, dict) or "custom_operators" not in data: + if not isinstance(data, dict) or "registrations" not in data: logger.warning(f"Malformed op registry at {path}, resetting to empty.") - return json.loads(json.dumps(_EMPTY_REGISTRY)) + return _empty_registry() return data except (json.JSONDecodeError, OSError) as exc: logger.warning(f"Failed to read op registry at {path}: {exc}") - return json.loads(json.dumps(_EMPTY_REGISTRY)) + return _empty_registry() def _write_registry(data: dict) -> None: @@ -157,6 +187,66 @@ def _write_registry(data: dict) -> None: raise +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _collect_modules_for_path(abs_path: str) -> List[str]: + """Return the sys.modules keys that were loaded from *abs_path*. + + For a single file, this is the module whose ``__file__`` matches. + For a directory/package, this includes all modules whose ``__file__`` + lives under *abs_path*. + """ + result = [] + for mod_name, mod in list(sys.modules.items()): + mod_file = getattr(mod, "__file__", None) + if mod_file is None: + continue + try: + mod_file = os.path.realpath(mod_file) + except (TypeError, OSError): + continue + if os.path.isfile(abs_path): + if mod_file == abs_path: + result.append(mod_name) + elif os.path.isdir(abs_path): + if mod_file.startswith(abs_path + os.sep) or mod_file == abs_path: + result.append(mod_name) + return result + + +def _ops_for_path(abs_path: str) -> List[str]: + """Return the OPERATORS names whose class was defined under *abs_path*. + + Inspects each registered operator class to find its source file and + checks whether it matches *abs_path* (file) or lives under it + (directory). Only called during unregister/list — not on the hot path. + """ + import inspect + + try: + from data_juicer.ops.base_op import OPERATORS + except ImportError: + return [] + + abs_path = os.path.realpath(abs_path) + result = [] + for name, cls in list(OPERATORS.modules.items()): + try: + cls_file = os.path.realpath(inspect.getfile(cls)) + except (TypeError, OSError): + continue + if os.path.isfile(abs_path): + if cls_file == abs_path: + result.append(name) + elif os.path.isdir(abs_path): + if cls_file.startswith(abs_path + os.sep): + result.append(name) + return sorted(result) + + # --------------------------------------------------------------------------- # Custom Op management # --------------------------------------------------------------------------- @@ -167,120 +257,103 @@ def register_persistent(paths: List[str]) -> dict: current-process ``OPERATORS``. *paths* is a list of file / directory paths (same semantics as - ``load_custom_operators``). + ``load_custom_operators``). Each path becomes a top-level key in the + registry so that unregister and reload operate at the same granularity. - Returns ``{"registered": [...], "warnings": [...]}``. + If a path is already registered it is skipped (idempotent). + + Returns ``{"registered": [...], "skipped": [...], "warnings": [...]}``. """ from data_juicer.ops.base_op import OPERATORS - # Snapshot operator names before loading. - before = set(OPERATORS.modules.keys()) - warnings: List[str] = [] + skipped: List[str] = [] valid_paths: List[str] = [] + + registry = _read_registry() + existing_paths = set(registry.get("registrations", {}).keys()) + for p in paths: - abs_p = os.path.abspath(p) + abs_p = os.path.realpath(p) if not os.path.exists(abs_p): warnings.append(f"Path does not exist: {abs_p}") continue + if abs_p in existing_paths: + skipped.append(abs_p) + continue valid_paths.append(abs_p) + # Snapshot operator names before loading. + before = set(OPERATORS.modules.keys()) + if valid_paths: load_custom_operators(valid_paths) # Diff to find newly registered names. after = set(OPERATORS.modules.keys()) - new_names = sorted(after - before) - - # Build a mapping from op name -> source path. - name_to_path: Dict[str, str] = {} - for p in valid_paths: - abs_p = os.path.abspath(p) - if os.path.isfile(abs_p): - # Find which new names came from this file. - for n in new_names: - op_cls = OPERATORS.modules.get(n) - if op_cls is not None: - try: - cls_file = os.path.abspath(inspect.getfile(op_cls)) - if cls_file == abs_p: - name_to_path[n] = abs_p - except (TypeError, OSError): - pass - elif os.path.isdir(abs_p): - # Directory – inspect each new op to find its file. - for n in new_names: - if n in name_to_path: - continue - op_cls = OPERATORS.modules.get(n) - if op_cls is not None: - try: - cls_file = os.path.abspath(inspect.getfile(op_cls)) - if cls_file.startswith(abs_p): - name_to_path[n] = cls_file - except (TypeError, OSError): - pass - - # Fallback: for any new name not yet mapped, try inspect. - for n in new_names: - if n not in name_to_path: - op_cls = OPERATORS.modules.get(n) - if op_cls is not None: - try: - name_to_path[n] = os.path.abspath(inspect.getfile(op_cls)) - except (TypeError, OSError): - warnings.append(f"Could not determine source path for '{n}'") - - # Persist. - registry = _read_registry() - now = datetime.now().isoformat(timespec="seconds") - for n in new_names: - src = name_to_path.get(n) - if src: - registry["custom_operators"][n] = { - "source_path": src, + all_new_names = sorted(after - before) + + # Persist only newly registered paths. + if valid_paths: + now = datetime.now().isoformat(timespec="seconds") + for abs_p in valid_paths: + path_type = "directory" if os.path.isdir(abs_p) else "file" + registry["registrations"][abs_p] = { + "type": path_type, "registered_at": now, } + _write_registry(registry) - _write_registry(registry) + if skipped: + for sp in skipped: + warnings.append(f"Path already registered: {sp}") + + return {"registered": all_new_names, "skipped": skipped, "warnings": warnings} - return {"registered": new_names, "warnings": warnings} +def unregister_paths(paths: List[str]) -> dict: + """Remove the given registration paths (and all their operators) from + the persistent registry and the current-process ``OPERATORS``. -def unregister_ops(names: List[str]) -> dict: - """Remove the given custom operators from the persistent registry and - the current-process ``OPERATORS``. + Paths must match exactly what was originally registered. For example, + if a directory was registered, only that directory path can be used to + unregister — individual files within it are not accepted. - Returns ``{"removed": [...], "not_found": [...]}``. + Returns ``{"removed": [...], "not_found": [...], "warnings": [...]}``. """ from data_juicer.ops.base_op import OPERATORS - registry = _read_registry() removed: List[str] = [] not_found: List[str] = [] + warnings: List[str] = [] - for name in names: - if name in registry["custom_operators"]: - meta = registry["custom_operators"][name] - del registry["custom_operators"][name] - removed.append(name) - - # Also remove from in-process registry. - OPERATORS.unregister_module(name) - - # Remove from sys.modules using the source file stem for - # precise matching — avoids accidentally removing built-in - # modules that happen to share a suffix. - src = meta.get("source_path") - if src: - module_name = os.path.splitext(os.path.basename(src))[0] - if module_name in sys.modules: - del sys.modules[module_name] + registry = _read_registry() + for p in paths: + abs_p = os.path.realpath(p) + if abs_p in registry["registrations"]: + del registry["registrations"][abs_p] + removed.append(abs_p) + + # Remove all operators whose class was defined under this path. + for name in _ops_for_path(abs_p): + OPERATORS.unregister_module(name) + + # Remove associated modules from sys.modules. + for mod_name in _collect_modules_for_path(abs_p): + sys.modules.pop(mod_name, None) else: - not_found.append(name) + not_found.append(abs_p) _write_registry(registry) - return {"removed": removed, "not_found": not_found} + + if not_found: + warnings.append( + "Only exact registration paths can be unregistered " + "(e.g. if a directory was registered, the entire directory " + "path must be used). Run 'list' to see all registered paths." + ) + + return {"removed": removed, "not_found": not_found, "warnings": warnings} def reset_registry() -> dict: @@ -292,21 +365,19 @@ def reset_registry() -> dict: from data_juicer.ops.base_op import OPERATORS registry = _read_registry() - removed = sorted(registry["custom_operators"].keys()) + registrations = registry.get("registrations", {}) + removed_paths = sorted(registrations.keys()) - custom_ops = registry["custom_operators"] - for name in removed: - OPERATORS.unregister_module(name) - meta = custom_ops.get(name) - if meta and "source_path" in meta: - module_name = os.path.splitext(os.path.basename(meta["source_path"]))[0] - if module_name in sys.modules: - del sys.modules[module_name] + for reg_path in removed_paths: + for name in _ops_for_path(reg_path): + OPERATORS.unregister_module(name) + for mod_name in _collect_modules_for_path(reg_path): + sys.modules.pop(mod_name, None) - registry["custom_operators"] = {} + registry["registrations"] = {} _write_registry(registry) - return {"removed": removed} + return {"removed": removed_paths} # --------------------------------------------------------------------------- @@ -315,12 +386,40 @@ def reset_registry() -> dict: def list_registered() -> dict: - """Return the contents of the persistent registry. + """Return the contents of the persistent registry with a live operator + view. - Returns ``{"custom_operators": {...}}``. + Returns a dict with two views: + + - ``"registrations"``: the path-keyed registry data, each entry + augmented with a live ``"operators"`` list derived from the + current-process ``OPERATORS``. + - ``"custom_operators"``: a flattened ``{op_name: {"source_path": ..., + "registered_at": ...}}`` dict for backward compatibility with + tooling that enumerates custom op names (e.g. ``op_search``). """ registry = _read_registry() - return {"custom_operators": registry.get("custom_operators", {})} + registrations = registry.get("registrations", {}) + + # Augment each registration with a live operator list. + augmented: Dict[str, dict] = {} + flat: Dict[str, dict] = {} + for reg_path, meta in registrations.items(): + live_ops = _ops_for_path(reg_path) + augmented[reg_path] = { + **meta, + "operators": live_ops, + } + for name in live_ops: + flat[name] = { + "source_path": reg_path, + "registered_at": meta.get("registered_at", ""), + } + + return { + "registrations": augmented, + "custom_operators": flat, + } # --------------------------------------------------------------------------- @@ -332,55 +431,54 @@ def load_persistent_custom_ops() -> dict: """Load all custom operators from the persistent registry into the current-process ``OPERATORS``. - Entries whose source files no longer exist are automatically removed + Entries whose source paths no longer exist are automatically removed from the registry and a warning is logged. Returns ``{"loaded": [...], "cleaned": [...], "warnings": [...]}``. """ registry = _read_registry() - custom_ops = registry.get("custom_operators", {}) + registrations = registry.get("registrations", {}) - if not custom_ops: + if not registrations: return {"loaded": [], "cleaned": [], "warnings": []} loaded: List[str] = [] cleaned: List[str] = [] warnings: List[str] = [] - # Validate paths first. + # Validate registration paths. valid_entries: Dict[str, dict] = {} - for name, meta in list(custom_ops.items()): - src = meta.get("source_path", "") - if not src or not os.path.exists(src): - logger.warning(f"Custom op '{name}' source not found at '{src}', " f"removing from registry.") - cleaned.append(name) - warnings.append(f"Cleaned stale entry '{name}': source '{src}' not found.") + for reg_path, meta in list(registrations.items()): + if not reg_path or not os.path.exists(reg_path): + logger.warning(f"Custom op registration path not found: '{reg_path}', " f"removing from registry.") + cleaned.append(reg_path) + warnings.append(f"Cleaned stale entry '{reg_path}': path not found.") else: - valid_entries[name] = meta + valid_entries[reg_path] = meta # If we cleaned anything, persist the updated registry. if cleaned: - registry["custom_operators"] = valid_entries + registry = _read_registry() + for cp in cleaned: + registry["registrations"].pop(cp, None) _write_registry(registry) - # Load valid entries. - # Group by source_path to avoid loading the same file twice. - paths_to_load: List[str] = [] - seen_paths: set = set() - for name, meta in valid_entries.items(): - src = meta["source_path"] - if src not in seen_paths: - paths_to_load.append(src) - seen_paths.add(src) - + # Load valid entries one-by-one so that a failure in one path does + # not prevent the remaining paths from loading. + paths_to_load = sorted(valid_entries.keys()) if paths_to_load: - try: - load_custom_operators(paths_to_load) - loaded = sorted(valid_entries.keys()) - except Exception as exc: - msg = f"Failed to load persistent custom ops: {exc}" - logger.error(msg) - warnings.append(msg) + from data_juicer.ops.base_op import OPERATORS + + before = set(OPERATORS.modules.keys()) + for path in paths_to_load: + try: + load_custom_operators([path]) + except Exception as exc: + msg = f"Failed to load custom op from '{path}': {exc}" + logger.error(msg) + warnings.append(msg) + after = set(OPERATORS.modules.keys()) + loaded = sorted(after - before) return {"loaded": loaded, "cleaned": cleaned, "warnings": warnings} @@ -417,12 +515,12 @@ def _build_parser() -> argparse.ArgumentParser: # --- unregister --- p_unreg = sub.add_parser( "unregister", - help="Unregister custom operator(s) from the persistent registry", + help="Unregister custom operator(s) by their registration path(s)", ) p_unreg.add_argument( - "names", + "paths", nargs="+", - help="Name(s) of custom operators to remove", + help="Registration path(s) to remove (file or directory)", ) # --- reset --- @@ -437,17 +535,20 @@ def _build_parser() -> argparse.ArgumentParser: def _cmd_list(args) -> int: """List registered custom operators.""" result = list_registered() - custom_ops = result.get("custom_operators", {}) - if not custom_ops: + registrations = result.get("registrations", {}) + if not registrations: print("No custom operators registered.") return 0 - print(f"Custom operators ({len(custom_ops)}):") - for name, meta in sorted(custom_ops.items()): - src = meta.get("source_path", "?") + total_ops = sum(len(m.get("operators", [])) for m in registrations.values()) + print(f"Custom operators ({total_ops} op(s) from {len(registrations)} path(s)):") + for reg_path, meta in sorted(registrations.items()): + path_type = meta.get("type", "?") reg_at = meta.get("registered_at", "?") - print(f" {name}") - print(f" source: {src}") + ops = meta.get("operators", []) + print(f" [{path_type}] {reg_path}") print(f" registered_at: {reg_at}") + ops_str = ", ".join(sorted(ops)) if ops else "(none loaded)" + print(f" operators: {ops_str}") return 0 @@ -455,13 +556,20 @@ def _cmd_register(args) -> int: """Register custom operator(s) persistently.""" result = register_persistent(args.paths) registered = result.get("registered", []) + skipped = result.get("skipped", []) warnings = result.get("warnings", []) if registered: print(f"Registered {len(registered)} operator(s):") for name in registered: print(f" + {name}") - else: + + if skipped: + print(f"Skipped {len(skipped)} already-registered path(s):") + for p in skipped: + print(f" ~ {p}") + + if not registered and not skipped: print("No new operators registered.") for warning in warnings: @@ -471,20 +579,24 @@ def _cmd_register(args) -> int: def _cmd_unregister(args) -> int: - """Unregister custom operator(s) from the persistent registry.""" - result = unregister_ops(args.names) + """Unregister custom operator(s) by their registration path(s).""" + result = unregister_paths(args.paths) removed = result.get("removed", []) not_found = result.get("not_found", []) + warnings = result.get("warnings", []) if removed: - print(f"Removed {len(removed)} operator(s):") - for name in removed: - print(f" - {name}") + print(f"Removed {len(removed)} registration(s):") + for p in removed: + print(f" - {p}") if not_found: print(f"Not found ({len(not_found)}):") - for name in not_found: - print(f" ? {name}") + for p in not_found: + print(f" ? {p}") + + for warning in warnings: + print(f" WARNING: {warning}", file=sys.stderr) return 0 @@ -495,9 +607,9 @@ def _cmd_reset(args) -> int: removed = result.get("removed", []) if removed: - print(f"Removed {len(removed)} custom operator(s):") - for name in removed: - print(f" - {name}") + print(f"Removed {len(removed)} registration(s):") + for p in removed: + print(f" - {p}") else: print("Registry was already empty.") return 0 diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index 793e732124..075f73f0a8 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -42,8 +42,8 @@ python -m data_juicer.utils.custom_op register /path/to/my_mapper.py # List all registered custom operators python -m data_juicer.utils.custom_op list -# Unregister by operator name -python -m data_juicer.utils.custom_op unregister my_mapper +# Unregister by registration path +python -m data_juicer.utils.custom_op unregister /path/to/my_mapper.py # Clear all custom operator registrations python -m data_juicer.utils.custom_op reset diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index 8533de1571..820951210d 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -42,8 +42,8 @@ python -m data_juicer.utils.custom_op register /path/to/my_mapper.py # 列出所有已注册的自定义算子 python -m data_juicer.utils.custom_op list -# 按算子名称取消注册 -python -m data_juicer.utils.custom_op unregister my_mapper +# 按注册路径取消注册 +python -m data_juicer.utils.custom_op unregister /path/to/my_mapper.py # 清除所有自定义算子注册 python -m data_juicer.utils.custom_op reset diff --git a/tests/utils/test_custom_op.py b/tests/utils/test_custom_op.py index 244369a5ad..9fec0c05e4 100644 --- a/tests/utils/test_custom_op.py +++ b/tests/utils/test_custom_op.py @@ -26,15 +26,16 @@ main as custom_op_main, register_persistent, reset_registry, - unregister_ops, + unregister_paths, ) from data_juicer.utils.registry import Registry from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + def _make_custom_op_file(tmp_dir: str, op_name: str = "my_test_mapper") -> str: """Create a minimal custom mapper .py file and return its path.""" code = textwrap.dedent(f"""\ - from data_juicer.ops.base_op import OPERATORS, Mapper + from data_juicer.ops import OPERATORS, Mapper @OPERATORS.register_module('{op_name}') class MyTestMapper(Mapper): @@ -50,10 +51,12 @@ def process_single(self, sample): f.write(code) return path + # --------------------------------------------------------------------------- # Unit tests for low-level API # --------------------------------------------------------------------------- + class RegistryUnregisterTest(DataJuicerTestCaseBase): """Test Registry.unregister_module (new method).""" @@ -74,6 +77,7 @@ def test_unregister_nonexistent(self): result = reg.unregister_module("no_such_op") self.assertFalse(result) + class RegistryPathTest(DataJuicerTestCaseBase): """Test get_registry_path with and without env override.""" @@ -96,6 +100,7 @@ def test_env_override(self): finally: del os.environ["DJ_CUSTOM_OP_REGISTRY"] + class ReadWriteRegistryTest(DataJuicerTestCaseBase): """Test _read_registry / _write_registry helpers.""" @@ -110,14 +115,17 @@ def tearDown(self): def test_read_empty(self): data = _read_registry() - self.assertEqual(data["version"], 1) - self.assertEqual(data["custom_operators"], {}) + self.assertEqual(data["version"], 2) + self.assertEqual(data["registrations"], {}) def test_write_and_read(self): payload = { - "version": 1, - "custom_operators": { - "foo": {"source_path": "/tmp/foo.py", "registered_at": "2026-01-01T00:00:00"} + "version": 2, + "registrations": { + "/tmp/foo.py": { + "type": "file", + "registered_at": "2026-01-01T00:00:00", + } }, } _write_registry(payload) @@ -128,7 +136,8 @@ def test_read_malformed(self): with open(self._reg_path, "w") as f: f.write("not json") data = _read_registry() - self.assertEqual(data["custom_operators"], {}) + self.assertEqual(data["registrations"], {}) + class RegisterPersistentTest(DataJuicerTestCaseBase): """Test register_persistent end-to-end.""" @@ -151,20 +160,24 @@ def test_register_and_json(self): self.assertIn(self._op_name, result["registered"]) self.assertIn(self._op_name, OPERATORS.modules) + # Registry should store the path, not operator names. data = _read_registry() - self.assertIn(self._op_name, data["custom_operators"]) - self.assertEqual( - data["custom_operators"][self._op_name]["source_path"], - os.path.abspath(self._op_file), - ) + abs_file = os.path.abspath(self._op_file) + self.assertIn(abs_file, data["registrations"]) + meta = data["registrations"][abs_file] + self.assertEqual(meta["type"], "file") + self.assertIn("registered_at", meta) + # No "operators" key in the persisted data. + self.assertNotIn("operators", meta) def test_register_nonexistent_path(self): result = register_persistent(["/no/such/path.py"]) self.assertEqual(result["registered"], []) self.assertGreater(len(result["warnings"]), 0) -class UnregisterOpsTest(DataJuicerTestCaseBase): - """Test unregister_ops.""" + +class UnregisterPathsTest(DataJuicerTestCaseBase): + """Test unregister_paths.""" def setUp(self): self._tmp = tempfile.TemporaryDirectory() @@ -181,18 +194,20 @@ def tearDown(self): self._tmp.cleanup() def test_unregister_existing(self): - result = unregister_ops([self._op_name]) - self.assertIn(self._op_name, result["removed"]) + abs_file = os.path.abspath(self._op_file) + result = unregister_paths([abs_file]) + self.assertIn(abs_file, result["removed"]) self.assertEqual(result["not_found"], []) self.assertNotIn(self._op_name, OPERATORS.modules) data = _read_registry() - self.assertNotIn(self._op_name, data["custom_operators"]) + self.assertNotIn(abs_file, data["registrations"]) def test_unregister_nonexistent(self): - result = unregister_ops(["no_such_op"]) + result = unregister_paths(["/no/such/path.py"]) self.assertEqual(result["removed"], []) - self.assertIn("no_such_op", result["not_found"]) + self.assertIn("/no/such/path.py", result["not_found"]) + class ResetRegistryTest(DataJuicerTestCaseBase): """Test reset_registry.""" @@ -212,12 +227,14 @@ def tearDown(self): self._tmp.cleanup() def test_reset_clears_all(self): + abs_file = os.path.abspath(self._op_file) result = reset_registry() - self.assertIn(self._op_name, result["removed"]) + self.assertIn(abs_file, result["removed"]) self.assertNotIn(self._op_name, OPERATORS.modules) data = _read_registry() - self.assertEqual(data["custom_operators"], {}) + self.assertEqual(data["registrations"], {}) + class ListRegisteredTest(DataJuicerTestCaseBase): """Test list_registered.""" @@ -233,6 +250,7 @@ def tearDown(self): def test_list_empty(self): result = list_registered() + self.assertEqual(result["registrations"], {}) self.assertEqual(result["custom_operators"], {}) def test_list_after_register(self): @@ -241,11 +259,18 @@ def test_list_after_register(self): register_persistent([op_file]) try: result = list_registered() + # Check the flattened view (live from OPERATORS) self.assertIn(op_name, result["custom_operators"]) + # Check the path-keyed view + abs_file = os.path.abspath(op_file) + self.assertIn(abs_file, result["registrations"]) + # The live "operators" list should contain the op + self.assertIn(op_name, result["registrations"][abs_file]["operators"]) finally: OPERATORS.unregister_module(op_name) sys.modules.pop(op_name, None) + class LoadPersistentCustomOpsTest(DataJuicerTestCaseBase): """Test load_persistent_custom_ops including stale-entry cleanup.""" @@ -276,10 +301,10 @@ def test_load_valid(self): def test_load_cleans_stale(self): payload = { - "version": 1, - "custom_operators": { - "stale_op": { - "source_path": "/no/such/file.py", + "version": 2, + "registrations": { + "/no/such/file.py": { + "type": "file", "registered_at": "2026-01-01T00:00:00", } }, @@ -287,11 +312,12 @@ def test_load_cleans_stale(self): _write_registry(payload) result = load_persistent_custom_ops() - self.assertIn("stale_op", result["cleaned"]) + self.assertIn("/no/such/file.py", result["cleaned"]) self.assertGreater(len(result["warnings"]), 0) data = _read_registry() - self.assertNotIn("stale_op", data["custom_operators"]) + self.assertNotIn("/no/such/file.py", data["registrations"]) + class CrossProcessVisibilityTest(DataJuicerTestCaseBase): """Test that a custom op registered in one process is visible in another.""" @@ -333,10 +359,12 @@ def test_cross_process(self): self.assertEqual(proc.returncode, 0, f"stderr: {proc.stderr}") self.assertIn("OK", proc.stdout) + # --------------------------------------------------------------------------- # CLI tests # --------------------------------------------------------------------------- + class CustomOpCLIListTest(DataJuicerTestCaseBase): """Test the 'list' CLI sub-command.""" @@ -353,6 +381,7 @@ def test_list_empty(self): rc = custom_op_main(["list"]) self.assertEqual(rc, 0) + class CustomOpCLIRegisterTest(DataJuicerTestCaseBase): """Test the 'register' CLI sub-command.""" @@ -374,6 +403,7 @@ def test_register(self): self.assertEqual(rc, 0) self.assertIn(self._op_name, OPERATORS.modules) + class CustomOpCLIUnregisterTest(DataJuicerTestCaseBase): """Test the 'unregister' CLI sub-command.""" @@ -392,10 +422,12 @@ def tearDown(self): self._tmp.cleanup() def test_unregister(self): - rc = custom_op_main(["unregister", self._op_name]) + abs_file = os.path.abspath(self._op_file) + rc = custom_op_main(["unregister", abs_file]) self.assertEqual(rc, 0) self.assertNotIn(self._op_name, OPERATORS.modules) + class CustomOpCLIResetTest(DataJuicerTestCaseBase): """Test the 'reset' CLI sub-command.""" @@ -418,6 +450,7 @@ def test_reset(self): self.assertEqual(rc, 0) self.assertNotIn(self._op_name, OPERATORS.modules) + class CustomOpCLINoCommandTest(DataJuicerTestCaseBase): """Test calling custom_op CLI with no command.""" @@ -425,5 +458,6 @@ def test_no_command(self): rc = custom_op_main([]) self.assertEqual(rc, 1) + if __name__ == "__main__": unittest.main()