diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index d39b11e947..8e4c5549c2 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -1,6 +1,5 @@ import argparse import copy -import importlib.util import json import os import shutil @@ -45,57 +44,15 @@ def timing_context(description): logger.debug(f"{description} took {elapsed_time:.2f} seconds") -def _generate_module_name(abs_path): - """Generate a module name based on the absolute path of the file.""" - return os.path.splitext(os.path.basename(abs_path))[0] +def load_custom_operators(paths): + """Dynamically load custom operator modules or packages in the specified path. + This is a re-export from ``data_juicer.utils.custom_op`` kept here for + backward compatibility. + """ + from data_juicer.utils.custom_op import load_custom_operators as _impl -def load_custom_operators(paths): - """Dynamically load custom operator modules or packages in the specified path.""" - for path in paths: - abs_path = os.path.abspath(path) - if os.path.isfile(abs_path): - module_name = _generate_module_name(abs_path) - if module_name in sys.modules: - existing_path = sys.modules[module_name].__file__ - raise RuntimeError( - f"Module '{module_name}' already loaded from '{existing_path}'. " - f"Conflict detected while loading '{abs_path}'." - ) - try: - spec = importlib.util.spec_from_file_location(module_name, abs_path) - if spec is None: - raise RuntimeError(f"Failed to create spec for '{abs_path}'") - module = importlib.util.module_from_spec(spec) - # register the module first to avoid recursive import issues - sys.modules[module_name] = module - spec.loader.exec_module(module) - except Exception as e: - raise RuntimeError(f"Error loading '{abs_path}' as '{module_name}': {e}") - - elif os.path.isdir(abs_path): - if not os.path.isfile(os.path.join(abs_path, "__init__.py")): - raise ValueError(f"Package directory '{abs_path}' must contain __init__.py") - package_name = os.path.basename(abs_path) - parent_dir = os.path.dirname(abs_path) - if package_name in sys.modules: - existing_path = sys.modules[package_name].__path__[0] - raise RuntimeError( - f"Package '{package_name}' already loaded from '{existing_path}'. " - f"Conflict detected while loading '{abs_path}'." - ) - original_sys_path = sys.path.copy() - try: - sys.path.insert(0, parent_dir) - importlib.import_module(package_name) - # record the loading path of the package (for subsequent conflict detection) - sys.modules[package_name].__loaded_from__ = abs_path - except Exception as e: - raise RuntimeError(f"Error loading package '{abs_path}': {e}") - finally: - sys.path = original_sys_path - else: - raise ValueError(f"Path '{abs_path}' is neither a file nor a directory") + _impl(paths) def build_base_parser() -> ArgumentParser: diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py index 15b2b7bda8..373d0e86ae 100644 --- a/data_juicer/ops/__init__.py +++ b/data_juicer/ops/__init__.py @@ -14,6 +14,10 @@ def timing_context(description): # yapf: disable with timing_context('Importing operator modules'): + # 1. Built-in operators (registered via @OPERATORS.register_module decorators + # that fire as each sub-package is imported) + # 2. Persistent custom operators (loaded from ~/.data_juicer/custom_op.json; + # no-op when the registry file does not exist) from . import aggregator, deduplicator, filter, grouper, mapper, pipeline, selector from .base_op import ( ATTRIBUTION_FILTERS, @@ -38,22 +42,26 @@ def timing_context(description): op_requirements_to_op_env_spec, ) + from data_juicer.utils.custom_op import load_persistent_custom_ops as _load_persistent # isort: skip # noqa: E501 + _load_persistent() + del _load_persistent + __all__ = [ - 'load_ops', - 'Filter', - 'Mapper', - 'Deduplicator', - 'Selector', - 'Grouper', - 'Aggregator', - 'UNFORKABLE', - 'NON_STATS_FILTERS', - 'OPERATORS', - 'TAGGING_OPS', - 'Pipeline', - 'OPEnvSpec', - 'op_requirements_to_op_env_spec', - 'OPEnvManager', - 'analyze_lazy_loaded_requirements', - 'analyze_lazy_loaded_requirements_for_code_file', + "load_ops", + "Filter", + "Mapper", + "Deduplicator", + "Selector", + "Grouper", + "Aggregator", + "UNFORKABLE", + "NON_STATS_FILTERS", + "OPERATORS", + "TAGGING_OPS", + "Pipeline", + "OPEnvSpec", + "op_requirements_to_op_env_spec", + "OPEnvManager", + "analyze_lazy_loaded_requirements", + "analyze_lazy_loaded_requirements_for_code_file", ] diff --git a/data_juicer/tools/DJ_mcp_granular_ops.py b/data_juicer/tools/DJ_mcp_granular_ops.py index 1ebd897e40..89cde83508 100644 --- a/data_juicer/tools/DJ_mcp_granular_ops.py +++ b/data_juicer/tools/DJ_mcp_granular_ops.py @@ -2,7 +2,7 @@ import inspect import os import sys -from typing import Annotated, Optional +from typing import Annotated, Optional, get_type_hints from pydantic import Field @@ -13,6 +13,30 @@ fastmcp = LazyLoader("mcp.server.fastmcp", "mcp[cli]") +def resolve_signature_annotations(func, sig: inspect.Signature) -> inspect.Signature: + """Resolve postponed/string annotations into real runtime types. + + When a module uses ``from __future__ import annotations``, all + annotations are stored as strings. This helper calls + ``typing.get_type_hints`` on the original callable to obtain the + real type objects and rebuilds the signature with them. + """ + try: + module = sys.modules.get(func.__module__, None) if hasattr(func, "__module__") else None + globalns = module.__dict__ if module else {} + hints = get_type_hints(func, globalns=globalns) + except Exception: + hints = {} + + new_params = [] + for name, param in sig.parameters.items(): + resolved_annotation = hints.get(name, param.annotation) + new_params.append(param.replace(annotation=resolved_annotation)) + + return_annotation = hints.get("return", sig.return_annotation) + return sig.replace(parameters=new_params, return_annotation=return_annotation) + + # Dynamic MCP Tool Creation def process_parameter(name: str, param: inspect.Parameter) -> inspect.Parameter: """ @@ -31,13 +55,18 @@ def create_operator_function(op, mcp): This function dynamically creates a function that can be registered as an MCP tool, with proper signature and documentation based on the operator's __init__ method. """ - sig = op["sig"] + raw_sig = op["sig"] + init_func = op.get("init_func") + if init_func is not None: + sig = resolve_signature_annotations(init_func, raw_sig) + else: + sig = raw_sig docstring = op["desc"] param_docstring = op["param_desc"] # Create new function signature with dataset_path as first parameter - # Consider adding other common parameters later, such as export_psth - new_parameters = [ + # Consider adding other common parameters later, such as export_path + fixed_params = [ inspect.Parameter("dataset_path", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str), inspect.Parameter( "export_path", @@ -51,11 +80,18 @@ def create_operator_function(op, mcp): annotation=Optional[int], default=None, ), - ] + [ + ] + op_params = [ process_parameter(name, param) for name, param in sig.parameters.items() if name not in ("args", "kwargs", "self") ] + # Merge all params, then reorder: required (no default) first, + # optional (with default) second, to satisfy Python's signature rule. + all_params = fixed_params + op_params + required_params = [p for p in all_params if p.default is inspect.Parameter.empty] + optional_params = [p for p in all_params if p.default is not inspect.Parameter.empty] + new_parameters = required_params + optional_params new_signature = sig.replace(parameters=new_parameters, return_annotation=str) def func(*args, **kwargs): @@ -66,7 +102,7 @@ def func(*args, **kwargs): export_path = bound_arguments.arguments.pop("export_path") dataset_path = bound_arguments.arguments.pop("dataset_path") np = bound_arguments.arguments.pop("np") - args_dict = {k: v for k, v in bound_arguments.arguments.items() if v} + args_dict = {k: v for k, v in bound_arguments.arguments.items() if v is not None} dj_cfg = { "dataset_path": dataset_path, diff --git a/data_juicer/tools/op_search.py b/data_juicer/tools/op_search.py index e7a1f79a35..49a224b613 100644 --- a/data_juicer/tools/op_search.py +++ b/data_juicer/tools/op_search.py @@ -159,22 +159,46 @@ class OPRecord: def __init__(self, name: str, op_cls: type, op_type: Optional[str] = None): self.name = name - self.type = op_type or op_cls.__module__.split(".")[2].lower() + + # --- module path: + # handling for custom ops --- + if op_type: + self.type = op_type + else: + module_parts = op_cls.__module__.split(".") + if len(module_parts) >= 3: + self.type = module_parts[2].lower() + else: + self.type = self._search_mro_for_type(op_cls) if self.type not in op_type_list: self.type = self._search_mro_for_type(op_cls) + self.desc = op_cls.__doc__ or "" self.tags = analyze_tag_from_cls(op_cls, name) self.sig = inspect.signature(op_cls.__init__) + self.init_func = op_cls.__init__ self.param_desc = extract_param_docstring(op_cls.__init__.__doc__ or "") self.param_desc_map = self._parse_param_desc() - self.source_path = str(get_source_path(op_cls)) - self.test_path = None - - test_path = f"tests/ops/{self.type}/test_{self.name}.py" - if not (PROJECT_ROOT / test_path).exists(): - test_path = find_test_by_searching_content(PROJECT_ROOT / "tests", op_cls.__name__ + "Test") or test_path - self.test_path = str(test_path) + # --- source path: handling for custom ops --- + try: + self.source_path = str(get_source_path(op_cls)) + except (ValueError, TypeError, OSError): + try: + self.source_path = str(Path(inspect.getfile(op_cls))) + except (TypeError, OSError): + self.source_path = "unknown" + + # --- test path: handling for custom ops --- + try: + test_path = f"tests/ops/{self.type}/test_{self.name}.py" + if not (PROJECT_ROOT / test_path).exists(): + test_path = ( + find_test_by_searching_content(PROJECT_ROOT / "tests", op_cls.__name__ + "Test") or test_path + ) + self.test_path = str(test_path) + except Exception: + self.test_path = None def __getitem__(self, item): try: @@ -209,6 +233,7 @@ def to_dict(self): "desc": self.desc, "tags": self.tags, "sig": self.sig, + "init_func": self.init_func, "param_desc": self.param_desc, "param_desc_map": self.param_desc_map, "source_path": self.source_path, @@ -441,26 +466,176 @@ def records_map(self): return self.all_ops -def main(query, tags, op_type): +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def _build_parser(): + import argparse + + parser = argparse.ArgumentParser( + prog="python -m data_juicer.tools.op_search", + description="Data-Juicer Operator Search & Query Tool", + ) + sub = parser.add_subparsers(dest="command", help="Available commands") + + # --- list --- + sub.add_parser( + "list", + help="List all operators (built-in + custom)", + ) + + # --- info --- + p_info = sub.add_parser( + "info", + help="Show detailed information about an operator", + ) + p_info.add_argument("name", help="Operator name") + + # --- search --- + p_search = sub.add_parser( + "search", + help="Search operators by keyword, regex, or tags", + ) + p_search.add_argument( + "query", + nargs="?", + default=None, + help="Search query (natural language or regex pattern)", + ) + p_search.add_argument( + "--mode", + choices=["bm25", "regex"], + default="bm25", + help="Search mode (default: bm25)", + ) + p_search.add_argument( + "--tags", + nargs="+", + default=None, + help="Filter by tags (e.g., gpu, cpu, text, image)", + ) + p_search.add_argument( + "--type", + dest="op_type", + default=None, + help="Filter by operator type (e.g., mapper, filter)", + ) + p_search.add_argument( + "--top-k", + type=int, + default=10, + help="Maximum number of results (default: 10)", + ) + + return parser + + +def _cmd_list(args) -> int: + """List all operators (built-in + custom).""" + from data_juicer.utils.custom_op import list_registered + + custom_info = list_registered() + custom_names = set(custom_info.get("custom_operators", {}).keys()) + all_names = sorted(OPERATORS.modules.keys()) + + print(f"Total operators: {len(all_names)}") + print(f" Built-in: {len(all_names) - len(custom_names)}") + print(f" Custom: {len(custom_names)}") + print() + for name in all_names: + marker = " [custom]" if name in custom_names else "" + print(f" {name}{marker}") + return 0 + + +def _cmd_info(args) -> int: + """Show detailed information about an operator.""" + import sys + + op_cls = OPERATORS.modules.get(args.name) + if op_cls is None: + print(f"Operator '{args.name}' not found.", file=sys.stderr) + return 1 + + record = OPRecord(name=args.name, op_cls=op_cls) + info = record.to_dict() + + print(f"Name: {info['name']}") + print(f"Type: {info['type']}") + print(f"Tags: {', '.join(info['tags']) if info['tags'] else 'none'}") + print(f"Source: {info['source_path']}") + print(f"Test: {info['test_path'] or 'none'}") + print(f"Signature: {info['sig']}") + print() + if info["desc"]: + print("Description:") + print(f" {info['desc'].strip()}") + print() + if info["param_desc_map"]: + print("Parameters:") + for pname, pdesc in info["param_desc_map"].items(): + print(f" {pname}: {pdesc}") + + return 0 + + +def _cmd_search(args) -> int: + """Search operators by keyword, regex, or tags.""" searcher = OPSearcher(include_formatter=True) - results = searcher.search_by_bm25(query=query, tags=tags, op_type=op_type) + query = args.query + tags = args.tags + op_type = args.op_type + + if args.mode == "regex" and query: + results = searcher.search_by_regex(query=query, tags=tags, op_type=op_type) + elif query: + results = searcher.search_by_bm25(query=query, tags=tags, op_type=op_type, top_k=args.top_k) + else: + results = searcher.search(tags=tags, op_type=op_type) - print(f"\nFound {len(results)} operators:") + print(f"Found {len(results)} operator(s):") for op in results: - print(f"\n[{op['type'].upper()}] {op['name']}") - print(f"Tags: {', '.join(op['tags'])}") - print(f"Description: {op['desc']}") - print(f"Parameters: {op['param_desc']}") - print(f"Parameter Descriptions: {op['param_desc_map']}") - print(f"Signature: {op['sig']}") - print("-" * 50) + print(f"\n [{op['type'].upper()}] {op['name']}") + print(f" Tags: {', '.join(op['tags'])}") + desc = (op.get("desc") or "").strip() + if desc: + first_line = desc.split("\n")[0].strip() + if len(first_line) > 80: + first_line = first_line[:77] + "..." + print(f" Desc: {first_line}") - print(searcher.records_map["nlpaug_en_mapper"]["source_path"]) - print(searcher.records_map["nlpaug_en_mapper"].test_path) + return 0 + + +_COMMAND_MAP = { + "list": _cmd_list, + "info": _cmd_info, + "search": _cmd_search, +} + + +def main(argv=None) -> int: + """CLI entry point for operator search & query.""" + + parser = _build_parser() + args = parser.parse_args(argv) + + if not args.command: + parser.print_help() + return 1 + + handler = _COMMAND_MAP.get(args.command) + if handler is None: + parser.print_help() + return 1 + + return handler(args) if __name__ == "__main__": - tags = [] - op_type = "formatter" - main(query="json", tags=tags, op_type=op_type) + import sys + + sys.exit(main()) diff --git a/data_juicer/utils/custom_op.py b/data_juicer/utils/custom_op.py new file mode 100644 index 0000000000..0d1771d354 --- /dev/null +++ b/data_juicer/utils/custom_op.py @@ -0,0 +1,644 @@ +""" +Custom Operator Management +=========================== + +Manages a persistent JSON registry at ``~/.data_juicer/custom_op.json`` +so that custom operators survive across processes. + +The registry stores only **registration paths** (file or directory) as +primary keys. It does **not** cache operator names — the actual operator +list is always derived at runtime from the in-process ``OPERATORS`` +registry after loading, so it is always up-to-date with the file system. + +On startup, ``load_persistent_custom_ops()`` replays those registrations +into the in-process ``OPERATORS`` registry, automatically cleaning up +entries whose source paths no longer exist. + +Also provides a CLI for managing custom operators:: + + python -m data_juicer.utils.custom_op list + python -m data_juicer.utils.custom_op register /path/to/my_mapper.py + python -m data_juicer.utils.custom_op unregister /path/to/my_mapper.py + python -m data_juicer.utils.custom_op reset +""" + +import argparse +import importlib.util +import json +import os +import sys +import tempfile +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + +from loguru import logger + +# --------------------------------------------------------------------------- +# Dynamic module loading +# --------------------------------------------------------------------------- + + +def _generate_module_name(abs_path): + """Generate a module name based on the absolute path of the file.""" + return os.path.splitext(os.path.basename(abs_path))[0] + + +def _rollback_operators(new_names): + """Remove the given operator names from OPERATORS. + + Used to roll back a partially-successful load: *new_names* is the set + of operators that were registered after the snapshot but before the + error occurred. + """ + try: + from data_juicer.ops.base_op import OPERATORS + except ImportError: + return + for name in new_names: + OPERATORS.unregister_module(name) + + +def load_custom_operators(paths): + """Dynamically load custom operator modules or packages in the specified path.""" + for path in paths: + abs_path = os.path.realpath(path) + if os.path.isfile(abs_path): + module_name = _generate_module_name(abs_path) + if module_name in sys.modules: + existing_path = sys.modules[module_name].__file__ + raise RuntimeError( + f"Module '{module_name}' already loaded from '{existing_path}'. " + f"Conflict detected while loading '{abs_path}'." + ) + from data_juicer.ops.base_op import OPERATORS + + ops_before = set(OPERATORS.modules.keys()) + try: + spec = importlib.util.spec_from_file_location(module_name, abs_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Failed to create spec for '{abs_path}'") + module = importlib.util.module_from_spec(spec) + # register the module first to avoid recursive import issues + sys.modules[module_name] = module + spec.loader.exec_module(module) + except Exception as e: + # Clean up partially-initialized module to avoid stale entries + sys.modules.pop(module_name, None) + ops_after = set(OPERATORS.modules.keys()) + _rollback_operators(ops_after - ops_before) + raise RuntimeError(f"Error loading '{abs_path}' as '{module_name}': {e}") + + elif os.path.isdir(abs_path): + if not os.path.isfile(os.path.join(abs_path, "__init__.py")): + raise ValueError(f"Package directory '{abs_path}' must contain __init__.py") + package_name = os.path.basename(abs_path) + parent_dir = os.path.dirname(abs_path) + if package_name in sys.modules: + existing_path = sys.modules[package_name].__path__[0] + raise RuntimeError( + f"Package '{package_name}' already loaded from '{existing_path}'. " + f"Conflict detected while loading '{abs_path}'." + ) + from data_juicer.ops.base_op import OPERATORS + + ops_before = set(OPERATORS.modules.keys()) + original_sys_path = sys.path.copy() + try: + sys.path.insert(0, parent_dir) + importlib.import_module(package_name) + # record the loading path of the package (custom attribute) + setattr(sys.modules[package_name], "__loaded_from__", abs_path) + except Exception as e: + ops_after = set(OPERATORS.modules.keys()) + _rollback_operators(ops_after - ops_before) + raise RuntimeError(f"Error loading package '{abs_path}': {e}") + finally: + sys.path = original_sys_path + else: + raise ValueError(f"Path '{abs_path}' is neither a file nor a directory") + + +# --------------------------------------------------------------------------- +# Path management +# --------------------------------------------------------------------------- + + +def get_registry_path() -> Path: + """Return the path to the persistent op registry JSON file. + + Defaults to ``~/.data_juicer/custom_op.json``. + Override with the ``DJ_CUSTOM_OP_REGISTRY`` environment variable. + """ + override = os.environ.get("DJ_CUSTOM_OP_REGISTRY") + if override: + return Path(override) + return Path.home() / ".data_juicer" / "custom_op.json" + + +# --------------------------------------------------------------------------- +# Low-level read / write helpers +# --------------------------------------------------------------------------- +def _empty_registry() -> dict: + """Return a fresh empty registry structure.""" + return {"version": 2, "registrations": {}} + + +def _read_registry() -> dict: + """Read the JSON registry. Returns the empty structure when the file + does not exist or is malformed. + """ + path = get_registry_path() + if not path.exists(): + return _empty_registry() + try: + with open(path, "r", encoding="utf-8") as fh: + data = json.load(fh) + if not isinstance(data, dict) or "registrations" not in data: + logger.warning(f"Malformed op registry at {path}, resetting to empty.") + return _empty_registry() + return data + except (json.JSONDecodeError, OSError) as exc: + logger.warning(f"Failed to read op registry at {path}: {exc}") + return _empty_registry() + + +def _write_registry(data: dict) -> None: + """Atomically write *data* to the registry JSON file. + + Uses write-to-tmp + ``os.replace`` to avoid partial writes. + """ + path = get_registry_path() + path.parent.mkdir(parents=True, exist_ok=True) + + # Write to a temp file in the same directory, then atomically replace. + fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp", prefix=".op_registry_") + try: + with os.fdopen(fd, "w", encoding="utf-8") as fh: + json.dump(data, fh, indent=2, ensure_ascii=False) + fh.write("\n") + os.replace(tmp_path, str(path)) + except BaseException: + # Clean up the temp file on failure. + try: + os.unlink(tmp_path) + except OSError: + pass + raise + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _collect_modules_for_path(abs_path: str) -> List[str]: + """Return the sys.modules keys that were loaded from *abs_path*. + + For a single file, this is the module whose ``__file__`` matches. + For a directory/package, this includes all modules whose ``__file__`` + lives under *abs_path*. + """ + result = [] + for mod_name, mod in list(sys.modules.items()): + mod_file = getattr(mod, "__file__", None) + if mod_file is None: + continue + try: + mod_file = os.path.realpath(mod_file) + except (TypeError, OSError): + continue + if os.path.isfile(abs_path): + if mod_file == abs_path: + result.append(mod_name) + elif os.path.isdir(abs_path): + if mod_file.startswith(abs_path + os.sep) or mod_file == abs_path: + result.append(mod_name) + return result + + +def _ops_for_path(abs_path: str) -> List[str]: + """Return the OPERATORS names whose class was defined under *abs_path*. + + Inspects each registered operator class to find its source file and + checks whether it matches *abs_path* (file) or lives under it + (directory). Only called during unregister/list — not on the hot path. + """ + import inspect + + try: + from data_juicer.ops.base_op import OPERATORS + except ImportError: + return [] + + abs_path = os.path.realpath(abs_path) + result = [] + for name, cls in list(OPERATORS.modules.items()): + try: + cls_file = os.path.realpath(inspect.getfile(cls)) + except (TypeError, OSError): + continue + if os.path.isfile(abs_path): + if cls_file == abs_path: + result.append(name) + elif os.path.isdir(abs_path): + if cls_file.startswith(abs_path + os.sep): + result.append(name) + return sorted(result) + + +# --------------------------------------------------------------------------- +# Custom Op management +# --------------------------------------------------------------------------- + + +def register_persistent(paths: List[str]) -> dict: + """Register custom operators to the persistent registry **and** the + current-process ``OPERATORS``. + + *paths* is a list of file / directory paths (same semantics as + ``load_custom_operators``). Each path becomes a top-level key in the + registry so that unregister and reload operate at the same granularity. + + If a path is already registered it is skipped (idempotent). + + Returns ``{"registered": [...], "skipped": [...], "warnings": [...]}``. + """ + from data_juicer.ops.base_op import OPERATORS + + warnings: List[str] = [] + skipped: List[str] = [] + valid_paths: List[str] = [] + + registry = _read_registry() + existing_paths = set(registry.get("registrations", {}).keys()) + + for p in paths: + abs_p = os.path.realpath(p) + if not os.path.exists(abs_p): + warnings.append(f"Path does not exist: {abs_p}") + continue + if abs_p in existing_paths: + skipped.append(abs_p) + continue + valid_paths.append(abs_p) + + # Snapshot operator names before loading. + before = set(OPERATORS.modules.keys()) + + if valid_paths: + load_custom_operators(valid_paths) + + # Diff to find newly registered names. + after = set(OPERATORS.modules.keys()) + all_new_names = sorted(after - before) + + # Persist only newly registered paths. + if valid_paths: + now = datetime.now().isoformat(timespec="seconds") + for abs_p in valid_paths: + path_type = "directory" if os.path.isdir(abs_p) else "file" + registry["registrations"][abs_p] = { + "type": path_type, + "registered_at": now, + } + _write_registry(registry) + + if skipped: + for sp in skipped: + warnings.append(f"Path already registered: {sp}") + + return {"registered": all_new_names, "skipped": skipped, "warnings": warnings} + + +def unregister_paths(paths: List[str]) -> dict: + """Remove the given registration paths (and all their operators) from + the persistent registry and the current-process ``OPERATORS``. + + Paths must match exactly what was originally registered. For example, + if a directory was registered, only that directory path can be used to + unregister — individual files within it are not accepted. + + Returns ``{"removed": [...], "not_found": [...], "warnings": [...]}``. + """ + from data_juicer.ops.base_op import OPERATORS + + removed: List[str] = [] + not_found: List[str] = [] + warnings: List[str] = [] + + registry = _read_registry() + for p in paths: + abs_p = os.path.realpath(p) + if abs_p in registry["registrations"]: + del registry["registrations"][abs_p] + removed.append(abs_p) + + # Remove all operators whose class was defined under this path. + for name in _ops_for_path(abs_p): + OPERATORS.unregister_module(name) + + # Remove associated modules from sys.modules. + for mod_name in _collect_modules_for_path(abs_p): + sys.modules.pop(mod_name, None) + else: + not_found.append(abs_p) + + _write_registry(registry) + + if not_found: + warnings.append( + "Only exact registration paths can be unregistered " + "(e.g. if a directory was registered, the entire directory " + "path must be used). Run 'list' to see all registered paths." + ) + + return {"removed": removed, "not_found": not_found, "warnings": warnings} + + +def reset_registry() -> dict: + """Clear **all** custom operators from the persistent registry and the + current-process ``OPERATORS``. + + Returns ``{"removed": [...]}``. + """ + from data_juicer.ops.base_op import OPERATORS + + registry = _read_registry() + registrations = registry.get("registrations", {}) + removed_paths = sorted(registrations.keys()) + + for reg_path in removed_paths: + for name in _ops_for_path(reg_path): + OPERATORS.unregister_module(name) + for mod_name in _collect_modules_for_path(reg_path): + sys.modules.pop(mod_name, None) + + registry["registrations"] = {} + _write_registry(registry) + + return {"removed": removed_paths} + + +# --------------------------------------------------------------------------- +# Query +# --------------------------------------------------------------------------- + + +def list_registered() -> dict: + """Return the contents of the persistent registry with a live operator + view. + + Returns a dict with two views: + + - ``"registrations"``: the path-keyed registry data, each entry + augmented with a live ``"operators"`` list derived from the + current-process ``OPERATORS``. + - ``"custom_operators"``: a flattened ``{op_name: {"source_path": ..., + "registered_at": ...}}`` dict for backward compatibility with + tooling that enumerates custom op names (e.g. ``op_search``). + """ + registry = _read_registry() + registrations = registry.get("registrations", {}) + + # Augment each registration with a live operator list. + augmented: Dict[str, dict] = {} + flat: Dict[str, dict] = {} + for reg_path, meta in registrations.items(): + live_ops = _ops_for_path(reg_path) + augmented[reg_path] = { + **meta, + "operators": live_ops, + } + for name in live_ops: + flat[name] = { + "source_path": reg_path, + "registered_at": meta.get("registered_at", ""), + } + + return { + "registrations": augmented, + "custom_operators": flat, + } + + +# --------------------------------------------------------------------------- +# Startup loading +# --------------------------------------------------------------------------- + + +def load_persistent_custom_ops() -> dict: + """Load all custom operators from the persistent registry into the + current-process ``OPERATORS``. + + Entries whose source paths no longer exist are automatically removed + from the registry and a warning is logged. + + Returns ``{"loaded": [...], "cleaned": [...], "warnings": [...]}``. + """ + registry = _read_registry() + registrations = registry.get("registrations", {}) + + if not registrations: + return {"loaded": [], "cleaned": [], "warnings": []} + + loaded: List[str] = [] + cleaned: List[str] = [] + warnings: List[str] = [] + + # Validate registration paths. + valid_entries: Dict[str, dict] = {} + for reg_path, meta in list(registrations.items()): + if not reg_path or not os.path.exists(reg_path): + logger.warning(f"Custom op registration path not found: '{reg_path}', " f"removing from registry.") + cleaned.append(reg_path) + warnings.append(f"Cleaned stale entry '{reg_path}': path not found.") + else: + valid_entries[reg_path] = meta + + # If we cleaned anything, persist the updated registry. + if cleaned: + registry = _read_registry() + for cp in cleaned: + registry["registrations"].pop(cp, None) + _write_registry(registry) + + # Load valid entries one-by-one so that a failure in one path does + # not prevent the remaining paths from loading. + paths_to_load = sorted(valid_entries.keys()) + if paths_to_load: + from data_juicer.ops.base_op import OPERATORS + + before = set(OPERATORS.modules.keys()) + for path in paths_to_load: + try: + load_custom_operators([path]) + except Exception as exc: + msg = f"Failed to load custom op from '{path}': {exc}" + logger.error(msg) + warnings.append(msg) + after = set(OPERATORS.modules.keys()) + loaded = sorted(after - before) + + return {"loaded": loaded, "cleaned": cleaned, "warnings": warnings} + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="python -m data_juicer.utils.custom_op", + description="Data-Juicer Custom Operator Management Tool " "(does not affect built-in operators)", + ) + sub = parser.add_subparsers(dest="command", help="Available commands") + + # --- list --- + sub.add_parser( + "list", + help="List registered custom operators", + ) + + # --- register --- + p_reg = sub.add_parser( + "register", + help="Register custom operator(s) persistently", + ) + p_reg.add_argument( + "paths", + nargs="+", + help="Path(s) to custom operator file(s) or directory(ies)", + ) + + # --- unregister --- + p_unreg = sub.add_parser( + "unregister", + help="Unregister custom operator(s) by their registration path(s)", + ) + p_unreg.add_argument( + "paths", + nargs="+", + help="Registration path(s) to remove (file or directory)", + ) + + # --- reset --- + sub.add_parser( + "reset", + help="Clear all custom operators from the persistent registry", + ) + + return parser + + +def _cmd_list(args) -> int: + """List registered custom operators.""" + result = list_registered() + registrations = result.get("registrations", {}) + if not registrations: + print("No custom operators registered.") + return 0 + total_ops = sum(len(m.get("operators", [])) for m in registrations.values()) + print(f"Custom operators ({total_ops} op(s) from {len(registrations)} path(s)):") + for reg_path, meta in sorted(registrations.items()): + path_type = meta.get("type", "?") + reg_at = meta.get("registered_at", "?") + ops = meta.get("operators", []) + print(f" [{path_type}] {reg_path}") + print(f" registered_at: {reg_at}") + ops_str = ", ".join(sorted(ops)) if ops else "(none loaded)" + print(f" operators: {ops_str}") + return 0 + + +def _cmd_register(args) -> int: + """Register custom operator(s) persistently.""" + result = register_persistent(args.paths) + registered = result.get("registered", []) + skipped = result.get("skipped", []) + warnings = result.get("warnings", []) + + if registered: + print(f"Registered {len(registered)} operator(s):") + for name in registered: + print(f" + {name}") + + if skipped: + print(f"Skipped {len(skipped)} already-registered path(s):") + for p in skipped: + print(f" ~ {p}") + + if not registered and not skipped: + print("No new operators registered.") + + for warning in warnings: + print(f" WARNING: {warning}", file=sys.stderr) + + return 0 + + +def _cmd_unregister(args) -> int: + """Unregister custom operator(s) by their registration path(s).""" + result = unregister_paths(args.paths) + removed = result.get("removed", []) + not_found = result.get("not_found", []) + warnings = result.get("warnings", []) + + if removed: + print(f"Removed {len(removed)} registration(s):") + for p in removed: + print(f" - {p}") + + if not_found: + print(f"Not found ({len(not_found)}):") + for p in not_found: + print(f" ? {p}") + + for warning in warnings: + print(f" WARNING: {warning}", file=sys.stderr) + + return 0 + + +def _cmd_reset(args) -> int: + """Clear all custom operators from the persistent registry.""" + result = reset_registry() + removed = result.get("removed", []) + + if removed: + print(f"Removed {len(removed)} registration(s):") + for p in removed: + print(f" - {p}") + else: + print("Registry was already empty.") + return 0 + + +_COMMAND_MAP = { + "list": _cmd_list, + "register": _cmd_register, + "unregister": _cmd_unregister, + "reset": _cmd_reset, +} + + +def main(argv: Optional[List[str]] = None) -> int: + """CLI entry point for custom operator management.""" + parser = _build_parser() + args = parser.parse_args(argv) + + if not args.command: + parser.print_help() + return 1 + + handler = _COMMAND_MAP.get(args.command) + if handler is None: + parser.print_help() + return 1 + + return handler(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/data_juicer/utils/registry.py b/data_juicer/utils/registry.py index e438d015f9..c8612cea46 100644 --- a/data_juicer/utils/registry.py +++ b/data_juicer/utils/registry.py @@ -82,6 +82,17 @@ def _register_module(self, module_name=None, module_cls=None, force=False): self._modules[module_name] = module_cls module_cls._name = module_name + def unregister_module(self, module_name: str) -> bool: + """Remove a module from the registry by name. + + :param module_name: name of the module to remove + :return: True if removed, False if not found. + """ + if module_name in self._modules: + del self._modules[module_name] + return True + return False + def register_module(self, module_name: str = None, module_cls: type = None, force=False): """ Register module class object to registry with the specified modulename. diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index 669af22f21..075f73f0a8 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -2,6 +2,8 @@ - [How-to Guide for Developers](#how-to-guide-for-developers) - [1. Build Your Own OPs Quickly](#1-build-your-own-ops-quickly) + - [Persistent Custom Operator Registration](#persistent-custom-operator-registration) + - [Build an OP Step by Step](#build-an-op-step-by-step) - [2. Build Your Own Data Recipes and Configs](#2-build-your-own-data-recipes-and-configs) - [2.1 Fruitful Config Sources \& Type Hints](#21-fruitful-config-sources--type-hints) - [2.2 Hierarchical Configs and Helps](#22-hierarchical-configs-and-helps) @@ -27,6 +29,43 @@ > The development process of the following example takes directly adding operators in the corresponding module of the source code as an example. If an operator is added externally, the new operator can be registered by passing the parameter `--custom-operator-paths` or configuring the `custom_operator_paths` parameter in the yaml file, for example: `custom_operator_paths: ['/path/to/new/op.py', '/path/to/new/ops/directory/]`. +### Persistent Custom Operator Registration + +In addition to the per-run `--custom-operator-paths` approach above, Data-Juicer provides a **persistent custom operator registry** so that externally developed operators survive across processes and sessions without repeating configuration. + +The registry is stored at `~/.data_juicer/custom_op.json` (override with the `DJ_CUSTOM_OP_REGISTRY` environment variable). Manage it via the CLI: + +```bash +# Register custom operator(s) — accepts file or directory paths +python -m data_juicer.utils.custom_op register /path/to/my_mapper.py + +# List all registered custom operators +python -m data_juicer.utils.custom_op list + +# Unregister by registration path +python -m data_juicer.utils.custom_op unregister /path/to/my_mapper.py + +# Clear all custom operator registrations +python -m data_juicer.utils.custom_op reset +``` + +Once registered, custom operators are **automatically loaded** on every Data-Juicer startup. Stale entries (whose source files no longer exist) are cleaned up automatically. + +You can also search and inspect both built-in and custom operators with the operator search tool: + +```bash +# List all operators (built-in + custom) +python -m data_juicer.tools.op_search list + +# Show detailed info for a specific operator +python -m data_juicer.tools.op_search info my_mapper + +# Search operators by keyword +python -m data_juicer.tools.op_search search "text length" +``` + +### Build an OP Step by Step + Assuming we want to add a new Filter operator called "TextLengthFilter" to get corpus of expected text length, we can follow the following steps to build it. 1. (Optional) If the new OP defines some statistical variables, please add the corresponding new `StatsKeysConstant` attribute in `data_juicer/utils/constant.py` for unified management. @@ -120,7 +159,13 @@ process: max_len: 1000 ``` -6. Community contributors can submit corresponding operator PRs and work with the Data-Juicer team to gradually improve it in subsequent PRs. Please see more details [below](#4-contribution-to-the-open-source-community). We greatly welcome co-construction and will [highlight acknowledgements](https://github.com/datajuicer/data-juicer?tab=readme-ov-file#contribution-and-acknowledgements)! +6. (Optional) If you develop custom operators outside the Data-Juicer source tree, you can **persistently register** them so they are available across all future sessions without adding `custom_operator_paths` every time: + + ```bash + python -m data_juicer.utils.custom_op register /path/to/text_length_filter.py + ``` + +7. Community contributors can submit corresponding operator PRs and work with the Data-Juicer team to gradually improve it in subsequent PRs. Please see more details [below](#4-contribution-to-the-open-source-community). We greatly welcome co-construction and will [highlight acknowledgements](https://github.com/datajuicer/data-juicer?tab=readme-ov-file#contribution-and-acknowledgements)! ## 2. Build Your Own Data Recipes and Configs diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index 4bfa947685..820951210d 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -2,6 +2,8 @@ - [开发者指南](#开发者指南) - [1. 快速构建你自己的算子](#1-快速构建你自己的算子) + - [持久化自定义算子注册](#持久化自定义算子注册) + - [逐步构建一个算子](#逐步构建一个算子) - [2. 构建你自己的数据菜谱和配置项](#2-构建你自己的数据菜谱和配置项) - [2.1 丰富的配置源和类型提示](#21-丰富的配置源和类型提示) - [2.2 层次化的配置和帮助](#22-层次化的配置和帮助) @@ -27,6 +29,43 @@ > 以下示例的开发过程以直接在源码对应模块中添加算子为例。如果外部添加算子,可以通过传参`--custom-operator-paths` 或 yaml配置文件中配置`custom_operator_paths`参数注册新算子,例如:`custom_operator_paths: ['/path/to/new/op.py', '/path/to/new/ops/directory/]`。 +### 持久化自定义算子注册 + +除了上述每次运行时指定 `--custom-operator-paths` 的方式外,Data-Juicer 还提供了**持久化自定义算子注册表**,使外部开发的算子能够跨进程、跨会话持续生效,无需重复配置。 + +注册表存储在 `~/.data_juicer/custom_op.json`(可通过环境变量 `DJ_CUSTOM_OP_REGISTRY` 覆盖路径)。通过 CLI 管理: + +```bash +# 注册自定义算子 — 支持文件或目录路径 +python -m data_juicer.utils.custom_op register /path/to/my_mapper.py + +# 列出所有已注册的自定义算子 +python -m data_juicer.utils.custom_op list + +# 按注册路径取消注册 +python -m data_juicer.utils.custom_op unregister /path/to/my_mapper.py + +# 清除所有自定义算子注册 +python -m data_juicer.utils.custom_op reset +``` + +注册完成后,自定义算子会在每次 Data-Juicer 启动时**自动加载**。源文件已不存在的失效条目会被自动清理。 + +你还可以使用算子搜索工具查询和检视内置算子与自定义算子: + +```bash +# 列出所有算子(内置 + 自定义) +python -m data_juicer.tools.op_search list + +# 查看某个算子的详细信息 +python -m data_juicer.tools.op_search info my_mapper + +# 按关键词搜索算子 +python -m data_juicer.tools.op_search search "text length" +``` + +### 逐步构建一个算子 + 下面以 "TextLengthFilter" 的算子(过滤仅包含预期文本长度的样本语料)为例,展示相应开发构建过程。 1. (可选) 如果该算子定义了某个统计变量,那么请在 `data_juicer/utils/constant.py` 文件中添加一个新的`StatsKeys`属性来统一保存管理。 @@ -121,7 +160,13 @@ process: max_len: 1000 ``` -6. 社区贡献者可在alpha状态后就提相应算子PR。此后该贡献者可以与Data-Juicer团队一起在后续PR中,将其渐进完善到beta和stable版本。更多细节请参考下方第4节。我们非常欢迎共建,并会[高亮致谢](https://github.com/datajuicer/data-juicer?tab=readme-ov-file#contribution-and-acknowledgements)! +6. (可选)如果你在 Data-Juicer 源码树之外开发自定义算子,可以**持久化注册**它们,这样在后续所有会话中都可以直接使用,无需每次都添加 `custom_operator_paths`: + + ```bash + python -m data_juicer.utils.custom_op register /path/to/text_length_filter.py + ``` + +7. 社区贡献者可在alpha状态后就提相应算子PR。此后该贡献者可以与Data-Juicer团队一起在后续PR中,将其渐进完善到beta和stable版本。更多细节请参考下方第4节。我们非常欢迎共建,并会[高亮致谢](https://github.com/datajuicer/data-juicer?tab=readme-ov-file#contribution-and-acknowledgements)! ## 2. 构建你自己的数据菜谱和配置项 diff --git a/tests/tools/test_op_search.py b/tests/tools/test_op_search.py index 856c281e96..4de833f31e 100644 --- a/tests/tools/test_op_search.py +++ b/tests/tools/test_op_search.py @@ -1,7 +1,10 @@ +import os +import tempfile import unittest from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase from data_juicer.tools.op_search import OPSearcher +from data_juicer.tools.op_search import main as op_search_main class OPRecordTest(DataJuicerTestCaseBase): @@ -250,5 +253,82 @@ def test_records_map_deprecated_still_returns_all_ops(self): self.assertEqual(records_map, searcher.all_ops) +# --------------------------------------------------------------------------- +# CLI tests +# --------------------------------------------------------------------------- + + +class OpSearchCLIListTest(DataJuicerTestCaseBase): + """Test the 'list' CLI sub-command.""" + + def setUp(self): + self._tmp_dir = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp_dir.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp_dir.cleanup() + + def test_list(self): + rc = op_search_main(["list"]) + self.assertEqual(rc, 0) + +class OpSearchCLIInfoTest(DataJuicerTestCaseBase): + """Test the 'info' CLI sub-command.""" + + def setUp(self): + self._tmp_dir = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp_dir.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp_dir.cleanup() + + def test_info_builtin(self): + """Info on a built-in op should succeed.""" + rc = op_search_main(["info", "text_length_filter"]) + self.assertEqual(rc, 0) + + def test_info_not_found(self): + rc = op_search_main(["info", "no_such_op_xyz"]) + self.assertEqual(rc, 1) + +class OpSearchCLISearchTest(DataJuicerTestCaseBase): + """Test the 'search' CLI sub-command.""" + + def setUp(self): + self._tmp_dir = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp_dir.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp_dir.cleanup() + + def test_search_bm25(self): + rc = op_search_main(["search", "text length"]) + self.assertEqual(rc, 0) + + def test_search_regex(self): + rc = op_search_main(["search", "text.*filter", "--mode", "regex"]) + self.assertEqual(rc, 0) + + def test_search_by_tags(self): + rc = op_search_main(["search", "--tags", "cpu"]) + self.assertEqual(rc, 0) + + def test_search_by_type(self): + rc = op_search_main(["search", "--type", "mapper"]) + self.assertEqual(rc, 0) + +class OpSearchCLINoCommandTest(DataJuicerTestCaseBase): + """Test calling op_search CLI with no command.""" + + def test_no_command(self): + rc = op_search_main([]) + self.assertEqual(rc, 1) + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/test_custom_op.py b/tests/utils/test_custom_op.py new file mode 100644 index 0000000000..9fec0c05e4 --- /dev/null +++ b/tests/utils/test_custom_op.py @@ -0,0 +1,463 @@ +""" +Tests for data_juicer.utils.custom_op + +These tests exercise the persistent custom-op registry end-to-end: +register, unregister, reset, list, load, stale-entry cleanup, +and the CLI sub-commands. + +Each test uses a private temp directory for the registry JSON and a +throwaway custom-op source file so that nothing leaks between tests. +""" + +import os +import subprocess +import sys +import tempfile +import textwrap +import unittest + +from data_juicer.ops.base_op import OPERATORS +from data_juicer.utils.custom_op import ( + _read_registry, + _write_registry, + get_registry_path, + list_registered, + load_persistent_custom_ops, + main as custom_op_main, + register_persistent, + reset_registry, + unregister_paths, +) +from data_juicer.utils.registry import Registry +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +def _make_custom_op_file(tmp_dir: str, op_name: str = "my_test_mapper") -> str: + """Create a minimal custom mapper .py file and return its path.""" + code = textwrap.dedent(f"""\ + from data_juicer.ops import OPERATORS, Mapper + + @OPERATORS.register_module('{op_name}') + class MyTestMapper(Mapper): + \"\"\"A test custom mapper.\"\"\" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def process_single(self, sample): + return sample + """) + path = os.path.join(tmp_dir, f"{op_name}.py") + with open(path, "w", encoding="utf-8") as f: + f.write(code) + return path + + +# --------------------------------------------------------------------------- +# Unit tests for low-level API +# --------------------------------------------------------------------------- + + +class RegistryUnregisterTest(DataJuicerTestCaseBase): + """Test Registry.unregister_module (new method).""" + + def test_unregister_existing(self): + reg = Registry("test_unreg") + + class Dummy: + pass + + reg.register_module("dummy_op", Dummy) + self.assertIn("dummy_op", reg.modules) + result = reg.unregister_module("dummy_op") + self.assertTrue(result) + self.assertNotIn("dummy_op", reg.modules) + + def test_unregister_nonexistent(self): + reg = Registry("test_unreg2") + result = reg.unregister_module("no_such_op") + self.assertFalse(result) + + +class RegistryPathTest(DataJuicerTestCaseBase): + """Test get_registry_path with and without env override.""" + + def test_default_path(self): + env_backup = os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + try: + path = get_registry_path() + self.assertTrue(str(path).endswith("custom_op.json")) + self.assertIn(".data_juicer", str(path)) + finally: + if env_backup is not None: + os.environ["DJ_CUSTOM_OP_REGISTRY"] = env_backup + + def test_env_override(self): + with tempfile.TemporaryDirectory() as tmp: + custom_path = os.path.join(tmp, "custom_reg.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = custom_path + try: + self.assertEqual(str(get_registry_path()), custom_path) + finally: + del os.environ["DJ_CUSTOM_OP_REGISTRY"] + + +class ReadWriteRegistryTest(DataJuicerTestCaseBase): + """Test _read_registry / _write_registry helpers.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_read_empty(self): + data = _read_registry() + self.assertEqual(data["version"], 2) + self.assertEqual(data["registrations"], {}) + + def test_write_and_read(self): + payload = { + "version": 2, + "registrations": { + "/tmp/foo.py": { + "type": "file", + "registered_at": "2026-01-01T00:00:00", + } + }, + } + _write_registry(payload) + data = _read_registry() + self.assertEqual(data, payload) + + def test_read_malformed(self): + with open(self._reg_path, "w") as f: + f.write("not json") + data = _read_registry() + self.assertEqual(data["registrations"], {}) + + +class RegisterPersistentTest(DataJuicerTestCaseBase): + """Test register_persistent end-to-end.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + self._op_name = "test_reg_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_register_and_json(self): + result = register_persistent([self._op_file]) + self.assertIn(self._op_name, result["registered"]) + self.assertIn(self._op_name, OPERATORS.modules) + + # Registry should store the path, not operator names. + data = _read_registry() + abs_file = os.path.abspath(self._op_file) + self.assertIn(abs_file, data["registrations"]) + meta = data["registrations"][abs_file] + self.assertEqual(meta["type"], "file") + self.assertIn("registered_at", meta) + # No "operators" key in the persisted data. + self.assertNotIn("operators", meta) + + def test_register_nonexistent_path(self): + result = register_persistent(["/no/such/path.py"]) + self.assertEqual(result["registered"], []) + self.assertGreater(len(result["warnings"]), 0) + + +class UnregisterPathsTest(DataJuicerTestCaseBase): + """Test unregister_paths.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + self._op_name = "test_unreg_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + register_persistent([self._op_file]) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_unregister_existing(self): + abs_file = os.path.abspath(self._op_file) + result = unregister_paths([abs_file]) + self.assertIn(abs_file, result["removed"]) + self.assertEqual(result["not_found"], []) + self.assertNotIn(self._op_name, OPERATORS.modules) + + data = _read_registry() + self.assertNotIn(abs_file, data["registrations"]) + + def test_unregister_nonexistent(self): + result = unregister_paths(["/no/such/path.py"]) + self.assertEqual(result["removed"], []) + self.assertIn("/no/such/path.py", result["not_found"]) + + +class ResetRegistryTest(DataJuicerTestCaseBase): + """Test reset_registry.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + self._op_name = "test_reset_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + register_persistent([self._op_file]) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_reset_clears_all(self): + abs_file = os.path.abspath(self._op_file) + result = reset_registry() + self.assertIn(abs_file, result["removed"]) + self.assertNotIn(self._op_name, OPERATORS.modules) + + data = _read_registry() + self.assertEqual(data["registrations"], {}) + + +class ListRegisteredTest(DataJuicerTestCaseBase): + """Test list_registered.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_list_empty(self): + result = list_registered() + self.assertEqual(result["registrations"], {}) + self.assertEqual(result["custom_operators"], {}) + + def test_list_after_register(self): + op_name = "test_list_mapper" + op_file = _make_custom_op_file(self._tmp.name, op_name) + register_persistent([op_file]) + try: + result = list_registered() + # Check the flattened view (live from OPERATORS) + self.assertIn(op_name, result["custom_operators"]) + # Check the path-keyed view + abs_file = os.path.abspath(op_file) + self.assertIn(abs_file, result["registrations"]) + # The live "operators" list should contain the op + self.assertIn(op_name, result["registrations"][abs_file]["operators"]) + finally: + OPERATORS.unregister_module(op_name) + sys.modules.pop(op_name, None) + + +class LoadPersistentCustomOpsTest(DataJuicerTestCaseBase): + """Test load_persistent_custom_ops including stale-entry cleanup.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_load_valid(self): + op_name = "test_load_mapper" + op_file = _make_custom_op_file(self._tmp.name, op_name) + register_persistent([op_file]) + + OPERATORS.unregister_module(op_name) + sys.modules.pop(op_name, None) + self.assertNotIn(op_name, OPERATORS.modules) + + result = load_persistent_custom_ops() + self.assertIn(op_name, result["loaded"]) + self.assertIn(op_name, OPERATORS.modules) + + OPERATORS.unregister_module(op_name) + sys.modules.pop(op_name, None) + + def test_load_cleans_stale(self): + payload = { + "version": 2, + "registrations": { + "/no/such/file.py": { + "type": "file", + "registered_at": "2026-01-01T00:00:00", + } + }, + } + _write_registry(payload) + + result = load_persistent_custom_ops() + self.assertIn("/no/such/file.py", result["cleaned"]) + self.assertGreater(len(result["warnings"]), 0) + + data = _read_registry() + self.assertNotIn("/no/such/file.py", data["registrations"]) + + +class CrossProcessVisibilityTest(DataJuicerTestCaseBase): + """Test that a custom op registered in one process is visible in another.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + self._op_name = "test_xproc_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + register_persistent([self._op_file]) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_cross_process(self): + """Spawn a subprocess that loads the registry and checks the op.""" + script = textwrap.dedent(f"""\ + import os, sys, json + os.environ["DJ_CUSTOM_OP_REGISTRY"] = "{self._reg_path}" + from data_juicer.utils.custom_op import load_persistent_custom_ops + from data_juicer.ops.base_op import OPERATORS + result = load_persistent_custom_ops() + if "{self._op_name}" in OPERATORS.modules: + print("OK") + else: + print("FAIL") + sys.exit(1) + """) + proc = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=60, + ) + self.assertEqual(proc.returncode, 0, f"stderr: {proc.stderr}") + self.assertIn("OK", proc.stdout) + + +# --------------------------------------------------------------------------- +# CLI tests +# --------------------------------------------------------------------------- + + +class CustomOpCLIListTest(DataJuicerTestCaseBase): + """Test the 'list' CLI sub-command.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + + def tearDown(self): + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_list_empty(self): + rc = custom_op_main(["list"]) + self.assertEqual(rc, 0) + + +class CustomOpCLIRegisterTest(DataJuicerTestCaseBase): + """Test the 'register' CLI sub-command.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + self._op_name = "cli_reg_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_register(self): + rc = custom_op_main(["register", self._op_file]) + self.assertEqual(rc, 0) + self.assertIn(self._op_name, OPERATORS.modules) + + +class CustomOpCLIUnregisterTest(DataJuicerTestCaseBase): + """Test the 'unregister' CLI sub-command.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + self._op_name = "cli_unreg_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + register_persistent([self._op_file]) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_unregister(self): + abs_file = os.path.abspath(self._op_file) + rc = custom_op_main(["unregister", abs_file]) + self.assertEqual(rc, 0) + self.assertNotIn(self._op_name, OPERATORS.modules) + + +class CustomOpCLIResetTest(DataJuicerTestCaseBase): + """Test the 'reset' CLI sub-command.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory() + self._reg_path = os.path.join(self._tmp.name, "custom_op.json") + os.environ["DJ_CUSTOM_OP_REGISTRY"] = self._reg_path + self._op_name = "cli_reset_mapper" + self._op_file = _make_custom_op_file(self._tmp.name, self._op_name) + register_persistent([self._op_file]) + + def tearDown(self): + OPERATORS.unregister_module(self._op_name) + sys.modules.pop(self._op_name, None) + os.environ.pop("DJ_CUSTOM_OP_REGISTRY", None) + self._tmp.cleanup() + + def test_reset(self): + rc = custom_op_main(["reset"]) + self.assertEqual(rc, 0) + self.assertNotIn(self._op_name, OPERATORS.modules) + + +class CustomOpCLINoCommandTest(DataJuicerTestCaseBase): + """Test calling custom_op CLI with no command.""" + + def test_no_command(self): + rc = custom_op_main([]) + self.assertEqual(rc, 1) + + +if __name__ == "__main__": + unittest.main()