From 6fcb5a2adbc3afe69bce07a34e107e8dbafe42cc Mon Sep 17 00:00:00 2001 From: cmgzn Date: Wed, 8 Apr 2026 17:09:19 +0800 Subject: [PATCH 1/5] refactor: add schema module for declarative config definitions --- data_juicer/config/__init__.py | 22 + data_juicer/config/config.py | 673 +--------------- data_juicer/config/schema.py | 938 ++++++++++++++++++++++ data_juicer/core/data/config_validator.py | 80 +- data_juicer/core/data/dataset_builder.py | 8 +- data_juicer/core/data/load_strategy.py | 26 +- data_juicer/tools/DJ_mcp_recipe_flow.py | 114 +-- tests/config/test_config.py | 6 +- tests/core/data/test_dataset_builder.py | 9 +- 9 files changed, 1125 insertions(+), 751 deletions(-) create mode 100644 data_juicer/config/schema.py diff --git a/data_juicer/config/__init__.py b/data_juicer/config/__init__.py index 02fc413268..c647defac5 100644 --- a/data_juicer/config/__init__.py +++ b/data_juicer/config/__init__.py @@ -11,6 +11,18 @@ update_op_attr, validate_work_dir_config, ) +from .schema import ( + CheckpointConfig, + DatasetConfig, + DJConfig, + EventLoggingConfig, + IntermediateStorageConfig, + PartitionConfig, + ResourceOptimizationConfig, + flatten_nested_namespaces, + get_defaults, + get_json_schema, +) __all__ = [ "init_configs", @@ -24,4 +36,14 @@ "validate_work_dir_config", "resolve_job_id", "resolve_job_directories", + "DJConfig", + "DatasetConfig", + "CheckpointConfig", + "PartitionConfig", + "ResourceOptimizationConfig", + "IntermediateStorageConfig", + "EventLoggingConfig", + "flatten_nested_namespaces", + "get_defaults", + "get_json_schema", ] diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index d39b11e947..2ff5c448d2 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -22,7 +22,6 @@ namespace_to_dict, ) from jsonargparse._typehints import ActionTypeHint -from jsonargparse.typing import ClosedUnitInterval, NonNegativeInt, PositiveInt from loguru import logger from data_juicer.ops.base_op import 
OPERATORS @@ -102,6 +101,8 @@ def build_base_parser() -> ArgumentParser: parser = ArgumentParser(default_env=True, default_config_files=None, usage=argparse.SUPPRESS) # required but mutually exclusive args group + # These are parser-behavior arguments, not configuration fields, + # so they are registered manually rather than via the schema. required_group = parser.add_mutually_exclusive_group(required=True) required_group.add_argument("--config", action=ActionConfigFile, help="Path to a dj basic configuration file.") required_group.add_argument( @@ -114,638 +115,10 @@ def build_base_parser() -> ArgumentParser: "disabled. Only available for Analyzer.", ) - parser.add_argument( - "--auto_num", - type=PositiveInt, - default=1000, - help="The number of samples to be analyzed " "automatically. It's 1000 in default.", - ) - - parser.add_argument( - "--hpo_config", type=str, help="Path to a configuration file when using auto-HPO tool.", required=False - ) - parser.add_argument( - "--data_probe_algo", - type=str, - default="uniform", - help='Sampling algorithm to use. Options are "uniform", ' - '"frequency_specified_field_selector", or ' - '"topk_specified_field_selector". Default is "uniform". Only ' - "used for dataset sampling", - required=False, - ) - parser.add_argument( - "--data_probe_ratio", - type=ClosedUnitInterval, - default=1.0, - help="The ratio of the sample size to the original dataset size. " # noqa: E251 - "Default is 1.0 (no sampling). 
Only used for dataset sampling", - required=False, - ) - - # basic global paras with extended type hints - # e.g., files can be mode include flags - # "fr": "path to a file that exists and is readable") - # "fc": "path to a file that can be created if it does not exist") - # "dw": "path to a directory that exists and is writeable") - # "dc": "path to a directory that can be created if it does not exist") - # "drw": "path to a directory that exists and is readable and writeable") - parser.add_argument("--project_name", type=str, default="hello_world", help="Name of your data process project.") - parser.add_argument( - "--executor_type", - type=str, - default="default", - choices=["default", "ray", "ray_partitioned"], - help='Type of executor, support "default", "ray", or "ray_partitioned".', - ) - parser.add_argument( - "--dataset_path", - type=str, - default="", - help="Path to datasets with optional weights(0.0-1.0), 1.0 as " - "default. Accepted format: dataset1-path dataset2-path " - " dataset3-path ...", - ) - parser.add_argument( - "--dataset", - type=Union[List[Dict], Dict], - default=[], - help="Dataset setting to define local/remote datasets; could be a " # noqa: E251 - "dict or a list of dicts; refer to " - "https://datajuicer.github.io/data-juicer/en/main/docs/DatasetCfg.html for more " - "detailed examples", - ) - parser.add_argument( - "--generated_dataset_config", - type=Dict, - default=None, - help="Configuration used to create a dataset. " # noqa: E251 - "The dataset will be created from this configuration if provided. " - "It must contain the `--type` field to specify the dataset name.", - ) - parser.add_argument( - "--validators", - type=List[Dict], - default=[], - help="List of validators to apply to the dataset. 
Each validator " # noqa: E251 - "must have a `type` field specifying the validator type.", - ) - parser.add_argument( - "--load_dataset_kwargs", - type=Dict, - default={}, - help="Extra keyword arguments passed through to the underlying " # noqa: E251 - "datasets.load_dataset() call. Useful for format-specific " - "options such as chunksize (JSON), columns (Parquet), or " - "delimiter (CSV). See the HuggingFace Datasets docs for " - "available options.", - ) - parser.add_argument( - "--read_options", - type=Dict, - default={}, - help="Read options passed through to PyArrow reading functions " - "(e.g., block_size for JSON reading). This configuration is " - "especially useful when reading large JSON files.", - ) - parser.add_argument( - "--work_dir", - type=str, - default=None, - help="Path to a work directory to store outputs during Data-Juicer " # noqa: E251 - "running. It's the directory where export_path is at in default.", - ) - parser.add_argument( - "--export_path", - type=str, - default="./outputs/hello_world/hello_world.jsonl", - help="Path to export and save the output processed dataset. The " # noqa: E251 - "directory to store the processed dataset will be the work " - "directory of this process.", - ) - parser.add_argument( - "--export_type", - type=str, - default=None, - help="The export format type. If it's not specified, Data-Juicer will parse from the export_path. The " - "supported types can be found in Exporter._router() for standalone mode and " - "RayExporter._SUPPORTED_FORMATS for ray mode", - ) - parser.add_argument( - "--export_shard_size", - type=NonNegativeInt, - default=0, - help="Shard size of exported dataset in Byte. In default, it's 0, " # noqa: E251 - "which means export the whole dataset into only one file. 
If " - "it's set a positive number, the exported dataset will be split " - "into several sub-dataset shards, and the max size of each shard " - "won't larger than the export_shard_size", - ) - parser.add_argument( - "--export_in_parallel", - type=bool, - default=False, - help="Whether to export the result dataset in parallel to a single " # noqa: E251 - "file, which usually takes less time. It only works when " - "export_shard_size is 0, and its default number of processes is " - "the same as the argument np. **Notice**: If it's True, " - "sometimes exporting in parallel might require much more time " - "due to the IO blocking, especially for very large datasets. " - "When this happens, False is a better choice, although it takes " - "more time.", - ) - parser.add_argument( - "--export_extra_args", - type=Dict, - default={}, - help="Other optional arguments for exporting in dict. For example, the key mapping info for exporting " - "the WebDataset format.", - ) - parser.add_argument( - "--export_aws_credentials", - type=Dict, - default=None, - help="Export-specific AWS credentials for S3 export. If export_path is S3 and this is not provided, " - "an error will be raised. Should contain aws_access_key_id, aws_secret_access_key, aws_region, " - "and optionally aws_session_token and endpoint_url.", - ) - parser.add_argument( - "--keep_stats_in_res_ds", - type=bool, - default=False, - help="Whether to keep the computed stats in the result dataset. If " # noqa: E251 - "it's False, the intermediate fields to store the stats " - "computed by Filters will be removed. Default: False.", - ) - parser.add_argument( - "--keep_hashes_in_res_ds", - type=bool, - default=False, - help="Whether to keep the computed hashes in the result dataset. If " # noqa: E251 - "it's False, the intermediate fields to store the hashes " - "computed by Deduplicators will be removed. 
Default: False.", - ) - parser.add_argument("--np", type=PositiveInt, default=4, help="Number of processes to process dataset.") - parser.add_argument( - "--text_keys", - type=Union[str, List[str]], - default="text", - help="Key name of field where the sample texts to be processed, e.g., " # noqa: E251 - "`text`, `text.instruction`, `text.output`, ... Note: currently, " - "we support specify only ONE key for each op, for cases " - "requiring multiple keys, users can specify the op multiple " - "times. We will only use the first key of `text_keys` when you " - "set multiple keys.", - ) - parser.add_argument( - "--image_key", - type=str, - default="images", - help="Key name of field to store the list of sample image paths.", # noqa: E251 - ) - parser.add_argument( - "--image_bytes_key", - type=str, - default="image_bytes", - help="Key name of field to store the list of sample image bytes.", # noqa: E251 - ) - parser.add_argument( - "--image_special_token", - type=str, - default=SpecialTokens.image, - help="The special token that represents an image in the text. In " # noqa: E251 - 'default, it\'s "<__dj__image>". You can specify your own special' - " token according to your input dataset.", - ) - parser.add_argument( - "--audio_key", - type=str, - default="audios", - help="Key name of field to store the list of sample audio paths.", # noqa: E251 - ) - parser.add_argument( - "--audio_special_token", - type=str, - default=SpecialTokens.audio, - help="The special token that represents an audio in the text. In " # noqa: E251 - 'default, it\'s "<__dj__audio>". You can specify your own special' - " token according to your input dataset.", - ) - parser.add_argument( - "--video_key", - type=str, - default="videos", - help="Key name of field to store the list of sample video paths.", # noqa: E251 - ) - parser.add_argument( - "--video_special_token", - type=str, - default=SpecialTokens.video, - help="The special token that represents a video in the text. 
In " - 'default, it\'s "<__dj__video>". You can specify your own special' - " token according to your input dataset.", - ) - parser.add_argument( - "--eoc_special_token", - type=str, - default=SpecialTokens.eoc, - help="The special token that represents the end of a chunk in the " # noqa: E251 - 'text. In default, it\'s "<|__dj__eoc|>". You can specify your ' - "own special token according to your input dataset.", - ) - parser.add_argument( - "--suffixes", - type=Union[str, List[str]], - default=[], - help="Suffixes of files that will be found and loaded. If not set, we " # noqa: E251 - "will find all suffix files, and select a suitable formatter " - "with the most files as default.", - ) - parser.add_argument( - "--turbo", - type=bool, - default=False, - help="Enable Turbo mode to maximize processing speed when batch size " "is 1.", # noqa: E251 - ) - parser.add_argument( - "--skip_op_error", - type=bool, - default=True, - help="Skip errors in OPs caused by unexpected invalid samples.", # noqa: E251 - ) - parser.add_argument( - "--use_cache", - type=bool, - default=True, - help="Whether to use the cache management of huggingface datasets. It " # noqa: E251 - "might take up lots of disk space when using cache", - ) - parser.add_argument( - "--ds_cache_dir", - type=str, - default=None, - help="Cache dir for HuggingFace datasets. In default it's the same " # noqa: E251 - "as the environment variable `HF_DATASETS_CACHE`, whose default " - 'value is usually "~/.cache/huggingface/datasets". If this ' - "argument is set to a valid path by users, it will override the " - "default cache dir. Modifying this arg might also affect the other two" - " paths to store downloaded and extracted datasets that depend on " - "`HF_DATASETS_CACHE`", - ) - parser.add_argument( - "--cache_compress", - type=str, - default=None, - help="The compression method of the cache file, which can be" - 'specified in ["gzip", "zstd", "lz4"]. 
If this parameter is' - "None, the cache file will not be compressed.", - ) - parser.add_argument( - "--open_monitor", - type=bool, - default=False, - help="Whether to open the monitor to trace resource utilization for " # noqa: E251 - "each OP during data processing. It's False in default.", - ) - parser.add_argument( - "--use_checkpoint", - type=bool, - default=False, - help="Whether to use the checkpoint management to save the latest " # noqa: E251 - "version of dataset to work dir when processing. Rerun the same " - "config will reload the checkpoint and skip ops before it. Cache " - "will be disabled when it is true . If args of ops before the " - "checkpoint are changed, all ops will be rerun from the " - "beginning.", - ) - # Enhanced checkpoint configuration for PartitionedRayExecutor - parser.add_argument( - "--checkpoint.enabled", - type=bool, - default=True, - help="Enable enhanced checkpointing for PartitionedRayExecutor", - ) - parser.add_argument( - "--checkpoint.strategy", - type=str, - default="every_n_ops", - choices=["every_op", "every_partition", "every_n_ops", "manual", "disabled"], - help="Checkpoint strategy: every_n_ops (default, balanced), every_op (max protection), " - "manual (after specific ops), disabled (best performance)", - ) - parser.add_argument( - "--checkpoint.n_ops", - type=int, - default=5, - help="Number of operations between checkpoints for every_n_ops strategy. 
" - "Default 5 balances fault tolerance with Ray optimization.", - ) - parser.add_argument( - "--checkpoint.op_names", - type=List[str], - default=[], - help="List of operation names to checkpoint for manual strategy", - ) - # Event logging configuration - parser.add_argument( - "--event_logging.enabled", - type=bool, - default=True, - help="Enable event logging for job tracking and resumption", - ) - # Logging configuration - parser.add_argument( - "--max_log_size_mb", - type=int, - default=100, - help="Maximum log file size in MB before rotation", - ) - parser.add_argument( - "--backup_count", - type=int, - default=5, - help="Number of backup log files to keep", - ) - # Storage configuration - parser.add_argument( - "--event_log_dir", - type=str, - default=None, - help="Separate directory for event logs (fast storage)", - ) - parser.add_argument( - "--checkpoint_dir", - type=str, - default=None, - help="Separate directory for checkpoints (large storage)", - ) - # Job management - parser.add_argument( - "--job_id", - type=str, - default=None, - help="Custom job ID for resumption and tracking. If not provided, a unique ID will be auto-generated.", - ) - parser.add_argument( - "--temp_dir", - type=str, - default=None, - help="Path to the temp directory to store intermediate caches when " # noqa: E251 - "cache is disabled. In default it's None, so the temp dir will " - "be specified by system. NOTICE: you should be caution when " - "setting this argument because it might cause unexpected program " - "behaviors when this path is set to an unsafe directory.", - ) - parser.add_argument( - "--open_tracer", - type=bool, - default=False, - help="Whether to open the tracer to trace samples changed during " # noqa: E251 - "process. It might take more time when opening tracer.", - ) - parser.add_argument( - "--op_list_to_trace", - type=List[str], - default=[], - help="Which ops will be traced by tracer. If it's empty, all ops in " # noqa: E251 - "cfg.process will be traced. 
Only available when open_tracer is " - "true.", - ) - parser.add_argument( - "--trace_num", - type=int, - default=10, - help="Number of samples extracted by tracer to show the dataset " - "difference before and after a op. Only available when " - "open_tracer is true.", - ) - parser.add_argument( - "--trace_keys", - type=List[str], - default=[], - help="List of field names to include in trace output. If set, the " - "specified fields' values will be included in each trace entry. " - "Only available when open_tracer is true.", - ) - parser.add_argument( - "--open_insight_mining", - type=bool, - default=False, - help="Whether to open insight mining to trace the OP-wise stats/tags " # noqa: E251 - "changes during process. It might take more time when opening " - "insight mining.", - ) - parser.add_argument( - "--op_list_to_mine", - type=List[str], - default=[], - help="Which OPs will be applied on the dataset to mine the insights " # noqa: E251 - "in their stats changes. Only those OPs that produce stats or " - "meta are valid. If it's empty, all OPs that produce stats and " - "meta will be involved. Only available when open_insight_mining " - "is true.", - ) - parser.add_argument( - "--min_common_dep_num_to_combine", - type=int, - default=-1, - help="The minimum number of common dependencies required to determine whether to merge two operation " - "environment specifications. If set to -1, it means no combination of operation environments, where " - "every OP has its own runtime environment during processing without any merging. If set to >= 0, " - "environments of OPs that share at least min_common_dep_num_to_combine common dependencies will be " - "merged. It will open the operator environment manager to automatically analyze and merge runtime " - "environment for different OPs. It helps different OPs share and reuse the same runtime environment to " - "reduce resource utilization. It's -1 in default. Only available in ray mode. 
", - ) - parser.add_argument( - "--conflict_resolve_strategy", - type=str, - default="split", - choices=["split", "overwrite", "latest"], - help="Strategy for resolving dependency conflicts, default is 'split' strategy. 'split': Keep the two " - "specs split when there is a conflict. 'overwrite': Overwrite the existing dependency with one " - "from the later OP. 'latest': Use the latest version of all specified dependency versions. " - "Only available when min_common_dep_num_to_combine >= 0.", - ) - parser.add_argument( - "--op_fusion", - type=bool, - default=False, - help="Whether to fuse operators that share the same intermediate " # noqa: E251 - "variables automatically. Op fusion might reduce the memory " - "requirements slightly but speed up the whole process.", - ) - parser.add_argument( - "--fusion_strategy", - type=str, - default="probe", - help='OP fusion strategy. Support ["greedy", "probe"] now. "greedy" ' # noqa: E251 - "means keep the basic OP order and put the fused OP to the last " - 'of each fused OP group. "probe" means Data-Juicer will probe ' - "the running speed for each OP at the beginning and reorder the " - "OPs and fused OPs according to their probed speed (fast to " - 'slow). It\'s "probe" in default.', - ) - parser.add_argument( - "--adaptive_batch_size", - type=bool, - default=False, - help="Whether to use adaptive batch sizes for each OP according to " # noqa: E251 - "the probed results. It's False in default.", - ) - parser.add_argument( - "--process", - type=List[Dict], - default=[], - help="List of several operators with their arguments, these ops will " # noqa: E251 - "be applied to dataset in order", - ) - parser.add_argument( - "--percentiles", - type=List[float], - default=[], - help="Percentiles to analyze the dataset distribution. Only used in " "Analysis.", # noqa: E251 - ) - parser.add_argument( - "--export_original_dataset", - type=bool, - default=False, - help="whether to export the original dataset with stats. 
If you only " # noqa: E251 - "need the stats of the dataset, setting it to false could speed " - "up the exporting.", - ) - parser.add_argument( - "--save_stats_in_one_file", - type=bool, - default=False, - help="Whether to save all stats to only one file. Only used in " "Analysis.", - ) - parser.add_argument("--ray_address", type=str, default="auto", help="The address of the Ray cluster.") - - # Partitioning configuration for PartitionedRayExecutor - # Support both flat and nested partition configuration - parser.add_argument( - "--partition_size", - type=int, - default=10000, - help="Number of samples per partition for PartitionedRayExecutor (legacy flat config)", - ) - parser.add_argument( - "--max_partition_size_mb", - type=int, - default=128, - help="Maximum partition size in MB for PartitionedRayExecutor (legacy flat config)", - ) - - parser.add_argument( - "--preserve_intermediate_data", - type=bool, - default=False, - help="Preserve intermediate data for debugging (legacy flat config)", - ) - - # partition configuration - parser.add_argument( - "--partition.mode", - type=str, - default="auto", - choices=["manual", "auto"], - help="Partition mode: manual (specify num_of_partitions) or auto (use partition size optimizer)", - ) - parser.add_argument( - "--partition.num_of_partitions", - type=int, - default=4, - help="Number of partitions for manual mode (ignored in auto mode)", - ) - parser.add_argument( - "--partition.target_size_mb", - type=int, - default=256, - help="Target partition size in MB for auto mode (128, 256, 512, or 1024). " - "Controls how large each partition should be. Smaller = more checkpoints & better recovery, " - "larger = less overhead. 
Default 256MB balances memory safety and efficiency.", - ) + # Register all configuration fields from the declarative schema + from data_juicer.config.schema import register_schema_to_parser - # Resource optimization configuration - parser.add_argument( - "--resource_optimization.auto_configure", - type=bool, - default=False, - help="Enable automatic optimization of partition size, worker count, and other resource-dependent settings (nested resource_optimization config)", - ) - - # Intermediate storage configuration - parser.add_argument( - "--intermediate_storage.preserve_intermediate_data", - type=bool, - default=False, - help="Preserve intermediate data for debugging (nested intermediate_storage config)", - ) - parser.add_argument( - "--intermediate_storage.cleanup_temp_files", - type=bool, - default=True, - help="Clean up temporary files after processing (nested intermediate_storage config)", - ) - parser.add_argument( - "--intermediate_storage.cleanup_on_success", - type=bool, - default=False, - help="Clean up intermediate files even on successful completion (nested intermediate_storage config)", - ) - parser.add_argument( - "--intermediate_storage.retention_policy", - type=str, - default="keep_all", - choices=["keep_all", "keep_failed_only", "cleanup_all"], - help="File retention policy (nested intermediate_storage config)", - ) - parser.add_argument( - "--intermediate_storage.max_retention_days", - type=int, - default=7, - help="Maximum retention days for files (nested intermediate_storage config)", - ) - - # Intermediate storage format configuration - parser.add_argument( - "--intermediate_storage.format", - type=str, - default="parquet", - choices=["parquet", "arrow", "jsonl"], - help="Storage format for checkpoints and intermediate data (nested intermediate_storage config)", - ) - parser.add_argument( - "--intermediate_storage.compression", - type=str, - default="snappy", - choices=["snappy", "gzip", "none"], - help="Compression format for storage files 
(nested intermediate_storage config)", - ) - - parser.add_argument( - "--intermediate_storage.write_partitions", - type=bool, - default=True, - help="Whether to write intermediate partition files to disk (nested intermediate_storage config). Set to false for better performance when intermediate files aren't needed.", - ) - - parser.add_argument( - "--partition_dir", - type=str, - default=None, - help="Directory to store partition files. Supports {work_dir} placeholder. If not set, defaults to {work_dir}/partitions.", - ) - - parser.add_argument("--custom-operator-paths", nargs="+", help="Paths to custom operator scripts or directories.") - parser.add_argument("--debug", action="store_true", help="Whether to run in debug mode.") - parser.add_argument( - "--auto_op_parallelism", - type=bool, - default=True, - help="Whether to automatically set operator parallelism.", - ) + register_schema_to_parser(parser) return parser @@ -852,6 +225,13 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None, l if not load_configs_only: display_config(cfg) + # Convert nested Namespace objects (from Pydantic sub-models) + # back to plain dicts so downstream code can keep using + # dict-access patterns (e.g. cfg.dataset["configs"]). + from data_juicer.config.schema import flatten_nested_namespaces + + flatten_nested_namespaces(cfg) + global global_cfg, global_parser global_cfg = cfg global_parser = parser @@ -1720,28 +1100,17 @@ def get_init_configs(cfg: Union[Namespace, Dict], load_configs_only: bool = True def get_default_cfg(): - """Get default config values from config_min.yaml""" - cfg = Namespace() + """Get default config values from schema definitions. 
- # Get path to config_min.yaml - config_dir = os.path.dirname(os.path.abspath(__file__)) - default_config_path = os.path.join(config_dir, "config_min.yaml") - - # Load default values from yaml - with open(default_config_path, "r", encoding="utf-8") as f: - defaults = yaml.safe_load(f) - - # Convert to flat dictionary for namespace - flat_defaults = { - # Add other top-level keys from config_min.yaml - **defaults - } - - # Update cfg with defaults - for key, value in flat_defaults.items(): - if not hasattr(cfg, key): - setattr(cfg, key, value) + Returns a Namespace populated with the default value for every + configuration field defined in the schema. This replaces the + previous approach of reading a subset from config_min.yaml. + """ + from data_juicer.config.schema import get_defaults + cfg = Namespace() + for key, value in get_defaults().items(): + setattr(cfg, key, value) return cfg diff --git a/data_juicer/config/schema.py b/data_juicer/config/schema.py new file mode 100644 index 0000000000..2ccf9266fd --- /dev/null +++ b/data_juicer/config/schema.py @@ -0,0 +1,938 @@ +""" +Declarative schema definitions for Data-Juicer global configuration. + +This module serves as the **single source of truth** for all configuration +fields. The entire config is defined as a single Pydantic ``DJConfig`` +model with nested sub-models for structured sections (dataset, checkpoint, +partition, etc.). + +External consumers can use the query APIs to: +- Get the JSON Schema: ``get_json_schema()`` +- Get all default values: ``get_defaults()`` +- Get the model class itself: ``DJConfig`` + +The ``register_schema_to_parser()`` function bridges the Pydantic model +back to jsonargparse's ``ArgumentParser``. 
+""" + +from __future__ import annotations + +from typing import Any, Dict, List, Literal, Optional, Union + +from jsonargparse import ArgumentParser +from pydantic import BaseModel, Field + +from data_juicer.utils.mm_utils import SpecialTokens + +# ============================================================ +# Pydantic sub-model definitions +# ============================================================ + + +class DatasetConfig(BaseModel): + """Configuration for dataset loading. + + The ``configs`` field is a list of per-source dataset configurations + (similar to how ``process`` holds per-operator configurations). + Each item's schema depends on its ``type`` and the corresponding + ``DataLoadStrategy`` subclass. + + Common fields shared by all strategies (e.g. ``weight``, ``type``) + are defined in ``DataLoadStrategy.BASE_CONFIG_RULES``. + Strategy-specific fields are defined in each subclass's + ``CONFIG_VALIDATION_RULES``. Available strategies can be discovered + via ``DataLoadStrategyRegistry``. + """ + + max_sample_num: Optional[int] = Field( + default=None, + description=( + "Maximum number of samples to load from the dataset configs. " + "When set, samples are drawn from each dataset source according " + "to its weight." + ), + gt=0, + ) + configs: List[Dict] = Field( + default_factory=list, + description=( + "List of dataset source configurations. Each entry is a dict " + "whose schema depends on its 'type' field (e.g. 'local', " + "'huggingface', 'modelscope'). Common fields for all strategies " + "(e.g. 'weight', 'type') are defined in " + "DataLoadStrategy.BASE_CONFIG_RULES. Strategy-specific fields " + "are defined in each strategy's CONFIG_VALIDATION_RULES. " + "Available strategies can be discovered via " + "DataLoadStrategyRegistry." 
+ ), + ) + + +class CheckpointConfig(BaseModel): + """Enhanced checkpoint configuration for PartitionedRayExecutor.""" + + enabled: bool = Field( + default=True, + description="Enable enhanced checkpointing for PartitionedRayExecutor", + ) + strategy: Literal["every_op", "every_partition", "every_n_ops", "manual", "disabled"] = Field( + default="every_n_ops", + description=( + "Checkpoint strategy: every_n_ops (default, balanced), every_op " + "(max protection), manual (after specific ops), disabled (best " + "performance)" + ), + ) + n_ops: int = Field( + default=5, + description=( + "Number of operations between checkpoints for every_n_ops " + "strategy. Default 5 balances fault tolerance with Ray " + "optimization." + ), + ) + op_names: List[str] = Field( + default_factory=list, + description="List of operation names to checkpoint for manual strategy", + ) + + +class PartitionConfig(BaseModel): + """Partition configuration for PartitionedRayExecutor.""" + + mode: Literal["manual", "auto"] = Field( + default="auto", + description=("Partition mode: manual (specify num_of_partitions) or auto " "(use partition size optimizer)"), + ) + num_of_partitions: int = Field( + default=4, + description="Number of partitions for manual mode (ignored in auto mode)", + ) + target_size_mb: int = Field( + default=256, + description=( + "Target partition size in MB for auto mode (128, 256, 512, or " + "1024). Controls how large each partition should be. Smaller = " + "more checkpoints & better recovery, larger = less overhead. " + "Default 256MB balances memory safety and efficiency." 
+ ), + ) + + +class ResourceOptimizationConfig(BaseModel): + """Resource optimization configuration.""" + + auto_configure: bool = Field( + default=False, + description=( + "Enable automatic optimization of partition size, worker count, " "and other resource-dependent settings" + ), + ) + + +class IntermediateStorageConfig(BaseModel): + """Intermediate storage configuration.""" + + preserve_intermediate_data: bool = Field( + default=False, + description="Preserve intermediate data for debugging", + ) + cleanup_temp_files: bool = Field( + default=True, + description="Clean up temporary files after processing", + ) + cleanup_on_success: bool = Field( + default=False, + description="Clean up intermediate files even on successful completion", + ) + retention_policy: Literal["keep_all", "keep_failed_only", "cleanup_all"] = Field( + default="keep_all", + description="File retention policy", + ) + max_retention_days: int = Field( + default=7, + description="Maximum retention days for files", + ) + format: Literal["parquet", "arrow", "jsonl"] = Field( + default="parquet", + description="Storage format for checkpoints and intermediate data", + ) + compression: Literal["snappy", "gzip", "none"] = Field( + default="snappy", + description="Compression format for storage files", + ) + write_partitions: bool = Field( + default=True, + description=( + "Whether to write intermediate partition files to disk. " + "Set to false for better performance when intermediate " + "files aren't needed." + ), + ) + + +class EventLoggingConfig(BaseModel): + """Event logging configuration.""" + + enabled: bool = Field( + default=True, + description="Enable event logging for job tracking and resumption", + ) + + +# ============================================================ +# DJConfig: the single Pydantic model for all configuration +# ============================================================ + +# Metadata dict for fields that need special argparse handling. 
+# Stored in Field.json_schema_extra so register_schema_to_parser +# can read it. +_ARGPARSE_ACTION = "argparse_action" +_ARGPARSE_NARGS = "argparse_nargs" + + +class DJConfig(BaseModel): + """Complete Data-Juicer configuration. + + This is the **single source of truth** for all configuration fields. + ``DJConfig.model_json_schema()`` produces a full JSON Schema that + external agents can consume directly. + """ + + # -------------------------------------------------------- + # general: basic global parameters + # -------------------------------------------------------- + project_name: str = Field( + default="hello_world", + description="Name of your data process project.", + ) + executor_type: Literal["default", "ray", "ray_partitioned"] = Field( + default="default", + description='Type of executor, support "default", "ray", or "ray_partitioned".', + ) + np: int = Field( + default=4, + gt=0, + description="Number of processes to process dataset.", + ) + turbo: bool = Field( + default=False, + description="Enable Turbo mode to maximize processing speed when batch size is 1.", + ) + skip_op_error: bool = Field( + default=True, + description="Skip errors in OPs caused by unexpected invalid samples.", + ) + auto_op_parallelism: bool = Field( + default=True, + description="Whether to automatically set operator parallelism.", + ) + debug: bool = Field( + default=False, + description="Whether to run in debug mode.", + json_schema_extra={_ARGPARSE_ACTION: "store_true"}, + ) + custom_operator_paths: Optional[List[str]] = Field( + default=None, + description=( + "Paths to custom operator scripts. Multiple paths can be " + "specified. Operators defined in these scripts will be " + "registered and available for use in the process pipeline." + ), + json_schema_extra={_ARGPARSE_NARGS: "+"}, + ) + use_dag: Optional[bool] = Field( + default=None, + description=( + "Whether to enable DAG execution planning. 
If None (default), " + "DAG execution is automatically enabled for distributed/" + "partitioned executors and disabled for standalone mode. Set " + "to True to force-enable DAG execution monitoring, or False " + "to disable it." + ), + ) + + # -------------------------------------------------------- + # dataset: data input configuration + # -------------------------------------------------------- + dataset_path: str = Field( + default="", + description=( + "Path to datasets with optional weights(0.0-1.0), 1.0 as " + "default. Accepted format: dataset1-path dataset2-path " + " dataset3-path ..." + ), + ) + dataset: Optional[DatasetConfig] = Field( + default=None, + description=( + "Dataset setting to define local/remote datasets. Contains " + "max_sample_num (optional max samples to load) and configs " + "(list of per-source dataset configurations). Refer to " + "https://datajuicer.github.io/data-juicer/en/main/docs/DatasetCfg.html " + "for more detailed examples. " + ), + ) + generated_dataset_config: Optional[Dict] = Field( + default=None, + description=( + "Configuration used to create a dataset. " + "The dataset will be created from this configuration if provided. " + "It must contain the `type` field to specify the dataset name." + ), + ) + validators: List[Dict] = Field( + default_factory=list, + description=( + "List of validators to apply to the dataset. Each validator " + "must have a `type` field specifying the validator type." + ), + ) + load_dataset_kwargs: Dict = Field( + default_factory=dict, + description=( + "Extra keyword arguments passed through to the underlying " + "datasets.load_dataset() call. Useful for format-specific " + "options such as chunksize (JSON), columns (Parquet), or " + "delimiter (CSV). See the HuggingFace Datasets docs for " + "available options." + ), + ) + suffixes: List[str] = Field( + default_factory=list, + description=( + "Suffixes of files that will be find and loaded. 
If not set, " + "we will find all suffix files under the dataset_path." + ), + ) + add_suffix: bool = Field( + default=False, + description=( + "Whether to add the file suffix to dataset meta info. If " + "True, a '__dj__suffix' field will be added to each sample " + "indicating which file type it came from. This is " + "automatically enabled when suffix_filter is used in the " + "process list, but can also be manually set to True." + ), + ) + + # -------------------------------------------------------- + # export: data output / export configuration + # -------------------------------------------------------- + export_path: str = Field( + default="./outputs/hello_world/hello_world.jsonl", + description=( + "Path to export and save the output processed dataset. The " + "directory to store the processed dataset will be the work " + "directory of this process." + ), + ) + export_type: Optional[str] = Field( + default=None, + description=( + "The export format type. If it's not specified, Data-Juicer will " + "parse from the export_path. The supported types can be found in " + "Exporter._router() for standalone mode and " + "RayExporter._SUPPORTED_FORMATS for ray mode" + ), + ) + export_shard_size: int = Field( + default=0, + ge=0, + description=( + "Shard size of exported dataset in Byte. In default, it's 0, " + "which means export the whole dataset into only one file. If " + "it's set a positive number, the exported dataset will be split " + "into several sub-dataset shards, and the max size of each shard " + "won't larger than the export_shard_size" + ), + ) + export_in_parallel: bool = Field( + default=False, + description=( + "Whether to export the result dataset in parallel to a single " + "file, which usually takes less time. It only works when " + "export_shard_size is 0, and its default number of processes is " + "the same as the argument np. 
**Notice**: If it's True, " + "sometimes exporting in parallel might require much more time " + "due to the IO blocking, especially for very large datasets. " + "When this happens, False is a better choice, although it takes " + "more time." + ), + ) + export_extra_args: Dict = Field( + default_factory=dict, + description=( + "Other optional arguments for exporting in dict. For example, " + "the key mapping info for exporting the WebDataset format." + ), + ) + export_aws_credentials: Optional[Dict] = Field( + default=None, + description=( + "Export-specific AWS credentials for S3 export. If export_path " + "is S3 and this is not provided, an error will be raised. Should " + "contain aws_access_key_id, aws_secret_access_key, aws_region, " + "and optionally aws_session_token and endpoint_url." + ), + ) + keep_stats_in_res_ds: bool = Field( + default=False, + description=( + "Whether to keep the computed stats in the result dataset. If " + "it's False, the intermediate fields to store the stats " + "computed by Filters will be removed. Default: False." + ), + ) + keep_hashes_in_res_ds: bool = Field( + default=False, + description=( + "Whether to keep the computed hashes in the result dataset. If " + "it's False, the intermediate fields to store the hashes " + "computed by Deduplicators will be removed. Default: False." + ), + ) + + # -------------------------------------------------------- + # multimodal: multimodal data processing keys & tokens + # -------------------------------------------------------- + text_keys: Union[str, List[str]] = Field( + default="text", + description=( + "Key name of field where the sample texts to be processed, e.g., " + "`text`, `text.instruction`, `text.output`, ... Note: currently, " + "we support specify only ONE key for each op, for cases " + "requiring multiple keys, users can specify the op multiple " + "times. We will only use the first key of `text_keys` when you " + "set multiple keys." 
+ ), + ) + image_key: str = Field( + default="images", + description="Key name of field to store the list of sample image paths.", + ) + image_bytes_key: str = Field( + default="image_bytes", + description="Key name of field to store the list of sample image bytes.", + ) + image_special_token: str = Field( + default=SpecialTokens.image, + description=( + "The special token that represents an image in the text. In " + 'default, it\'s "<__dj__image>". You can specify your own special' + " token according to your input dataset." + ), + ) + audio_key: str = Field( + default="audios", + description="Key name of field to store the list of sample audio paths.", + ) + audio_special_token: str = Field( + default=SpecialTokens.audio, + description=( + "The special token that represents an audio in the text. In " + 'default, it\'s "<__dj__audio>". You can specify your own special' + " token according to your input dataset." + ), + ) + video_key: str = Field( + default="videos", + description="Key name of field to store the list of sample video paths.", + ) + video_special_token: str = Field( + default=SpecialTokens.video, + description=( + "The special token that represents a video in the text. In " + 'default, it\'s "<__dj__video>". You can specify your own special' + " token according to your input dataset." + ), + ) + eoc_special_token: str = Field( + default=SpecialTokens.eoc, + description=( + "The special token that represents the end of a chunk in the " + 'text. In default, it\'s "<|__dj__eoc|>". You can specify your ' + "own special token according to your input dataset." + ), + ) + + # -------------------------------------------------------- + # cache: cache management + # -------------------------------------------------------- + use_cache: bool = Field( + default=True, + description=( + "Whether to use the cache management of huggingface datasets. 
It " + "might take up lots of disk space when using cache" + ), + ) + ds_cache_dir: Optional[str] = Field( + default=None, + description=( + "Cache dir for HuggingFace datasets. In default it's the same " + "as the environment variable `HF_DATASETS_CACHE`, whose default " + 'value is usually "~/.cache/huggingface/datasets". If this ' + "argument is set to a valid path by users, it will override the " + "default cache dir. Modifying this arg might also affect the " + "other two paths to store downloaded and extracted datasets that " + "depend on `HF_DATASETS_CACHE`" + ), + ) + cache_compress: Optional[str] = Field( + default=None, + description=( + "The compression method of the cache file, which can be " + 'specified in ["gzip", "zstd", "lz4"]. If this parameter is ' + "None, the cache file will not be compressed." + ), + ) + temp_dir: Optional[str] = Field( + default=None, + description=( + "Path to the temp directory to store intermediate caches when " + "cache is disabled. In default it's None, so the temp dir will " + "be specified by system. NOTICE: you should be caution when " + "setting this argument because it might cause unexpected program " + "behaviors when this path is set to an unsafe directory." + ), + ) + + # -------------------------------------------------------- + # checkpoint: checkpoint configuration + # -------------------------------------------------------- + use_checkpoint: bool = Field( + default=False, + description=( + "Whether to use the checkpoint management to save the latest " + "version of dataset to work dir when processing. Rerun the same " + "config will reload the checkpoint and skip ops before it. Cache " + "will be disabled when it is true . If args of ops before the " + "checkpoint are changed, all ops will be rerun from the " + "beginning." 
+ ), + ) + checkpoint: CheckpointConfig = Field( + default_factory=CheckpointConfig, + description="Enhanced checkpoint configuration for PartitionedRayExecutor.", + ) + + # -------------------------------------------------------- + # tracer: tracing and insight mining + # -------------------------------------------------------- + open_monitor: bool = Field( + default=False, + description=( + "Whether to open the monitor to trace resource utilization for " + "each OP during data processing. It's False in default." + ), + ) + open_tracer: bool = Field( + default=False, + description=( + "Whether to open the tracer to trace samples changed during " + "process. It might take more time when opening tracer." + ), + ) + op_list_to_trace: List[str] = Field( + default_factory=list, + description=( + "Which ops will be traced by tracer. If it's empty, all ops in " + "cfg.process will be traced. Only available when open_tracer is " + "true." + ), + ) + trace_num: int = Field( + default=10, + description=( + "Number of samples extracted by tracer to show the dataset " + "difference before and after a op. Only available when " + "open_tracer is true." + ), + ) + trace_keys: List[str] = Field( + default_factory=list, + description=( + "List of field names to include in trace output. If set, the " + "specified fields' values will be included in each trace entry. " + "Only available when open_tracer is true." + ), + ) + open_insight_mining: bool = Field( + default=False, + description=( + "Whether to open insight mining to trace the OP-wise stats/tags " + "changes during process. It might take more time when opening " + "insight mining." + ), + ) + op_list_to_mine: List[str] = Field( + default_factory=list, + description=( + "Which OPs will be applied on the dataset to mine the insights " + "in their stats changes. Only those OPs that produce stats or " + "meta are valid. If it's empty, all OPs that produce stats and " + "meta will be involved. 
Only available when open_insight_mining " + "is true." + ), + ) + + # -------------------------------------------------------- + # op_management: operator fusion and environment management + # -------------------------------------------------------- + op_fusion: bool = Field( + default=False, + description=( + "Whether to fuse operators that share the same intermediate " + "variables automatically. Op fusion might reduce the memory " + "requirements slightly but speed up the whole process." + ), + ) + fusion_strategy: str = Field( + default="probe", + description=( + 'OP fusion strategy. Support ["greedy", "probe"] now. "greedy" ' + "means keep the basic OP order and put the fused OP to the last " + 'of each fused OP group. "probe" means Data-Juicer will probe ' + "the running speed for each OP at the beginning and reorder the " + "OPs and fused OPs according to their probed speed (fast to " + 'slow). It\'s "probe" in default.' + ), + ) + adaptive_batch_size: bool = Field( + default=False, + description=( + "Whether to use adaptive batch sizes for each OP according to " "the probed results. It's False in default." + ), + ) + min_common_dep_num_to_combine: int = Field( + default=-1, + description=( + "The minimum number of common dependencies required to determine " + "whether to merge two operation environment specifications. If " + "set to -1, it means no combination of operation environments, " + "where every OP has its own runtime environment during processing " + "without any merging. If set to >= 0, environments of OPs that " + "share at least min_common_dep_num_to_combine common dependencies " + "will be merged. It will open the operator environment manager to " + "automatically analyze and merge runtime environment for " + "different OPs. It helps different OPs share and reuse the same " + "runtime environment to reduce resource utilization. It's -1 in " + "default. Only available in ray mode." 
+ ), + ) + conflict_resolve_strategy: Literal["split", "overwrite", "latest"] = Field( + default="split", + description=( + "Strategy for resolving dependency conflicts, default is 'split' " + "strategy. 'split': Keep the two specs split when there is a " + "conflict. 'overwrite': Overwrite the existing dependency with " + "one from the later OP. 'latest': Use the latest version of all " + "specified dependency versions. Only available when " + "min_common_dep_num_to_combine >= 0." + ), + ) + + # -------------------------------------------------------- + # process: processing pipeline + # -------------------------------------------------------- + process: List[Dict] = Field( + default_factory=list, + description=( + "List of several operators with their arguments, these ops will " "be applied to dataset in order" + ), + ) + + # -------------------------------------------------------- + # distributed: Ray configuration + # -------------------------------------------------------- + ray_address: str = Field( + default="auto", + description="The address of the Ray cluster.", + ) + + # -------------------------------------------------------- + # partition: partitioning configuration + # -------------------------------------------------------- + partition_size: int = Field( + default=10000, + description=("Number of samples per partition for PartitionedRayExecutor " "(legacy flat config)"), + ) + max_partition_size_mb: int = Field( + default=128, + description=("Maximum partition size in MB for PartitionedRayExecutor " "(legacy flat config)"), + ) + preserve_intermediate_data: bool = Field( + default=False, + description="Preserve intermediate data for debugging (legacy flat config)", + ) + partition: PartitionConfig = Field( + default_factory=PartitionConfig, + description="Partition configuration for PartitionedRayExecutor.", + ) + partition_dir: Optional[str] = Field( + default=None, + description=( + "Directory to store partition files. 
Supports {work_dir} " + "placeholder. If not set, defaults to {work_dir}/partitions." + ), + ) + + # -------------------------------------------------------- + # storage: intermediate storage and resource optimization + # -------------------------------------------------------- + checkpoint_dir: Optional[str] = Field( + default=None, + description="Separate directory for checkpoints (large storage)", + ) + resource_optimization: ResourceOptimizationConfig = Field( + default_factory=ResourceOptimizationConfig, + description="Resource optimization configuration.", + ) + intermediate_storage: IntermediateStorageConfig = Field( + default_factory=IntermediateStorageConfig, + description="Intermediate storage configuration.", + ) + + # -------------------------------------------------------- + # logging: event logging and log management + # -------------------------------------------------------- + event_logging: EventLoggingConfig = Field( + default_factory=EventLoggingConfig, + description="Event logging configuration.", + ) + max_log_size_mb: int = Field( + default=100, + description="Maximum log file size in MB before rotation", + ) + backup_count: int = Field( + default=5, + description="Number of backup log files to keep", + ) + event_log_dir: Optional[str] = Field( + default=None, + description="Separate directory for event logs (fast storage)", + ) + + # -------------------------------------------------------- + # job: job management + # -------------------------------------------------------- + job_id: Optional[str] = Field( + default=None, + description=( + "Custom job ID for resumption and tracking. If not provided, " "a unique ID will be auto-generated." + ), + ) + work_dir: Optional[str] = Field( + default=None, + description=( + "Path to a work directory to store outputs during Data-Juicer " + "running. It's the directory where export_path is at in default." 
+ ), + ) + + # -------------------------------------------------------- + # analysis: only for data analysis + # -------------------------------------------------------- + percentiles: List[float] = Field( + default_factory=list, + description=("Percentiles to analyze the dataset distribution. Only used in " "Analysis."), + ) + export_original_dataset: bool = Field( + default=False, + description=( + "whether to export the original dataset with stats. If you only " + "need the stats of the dataset, setting it to false could speed " + "up the exporting." + ), + ) + save_stats_in_one_file: bool = Field( + default=False, + description=("Whether to save all stats to only one file. Only used in " "Analysis."), + ) + + # -------------------------------------------------------- + # sampling: sandbox / HPO + # -------------------------------------------------------- + auto_num: int = Field( + default=1000, + gt=0, + description="The number of samples to be analyzed automatically. It's 1000 in default.", + ) + hpo_config: Optional[str] = Field( + default=None, + description="Path to a configuration file when using auto-HPO tool.", + ) + data_probe_algo: str = Field( + default="uniform", + description=( + 'Sampling algorithm to use. Options are "uniform", ' + '"frequency_specified_field_selector", or ' + '"topk_specified_field_selector". Default is "uniform". Only ' + "used for dataset sampling" + ), + ) + data_probe_ratio: float = Field( + default=1.0, + ge=0.0, + le=1.0, + description=( + "The ratio of the sample size to the original dataset size. " + "Default is 1.0 (no sampling). Only used for dataset sampling" + ), + ) + + +# ============================================================ +# Schema query APIs +# ============================================================ + + +def get_json_schema() -> Dict[str, Any]: + """Return a JSON Schema representation of the full configuration. 
+ + This leverages the ``DJConfig`` Pydantic model to automatically + generate a complete JSON Schema including all nested structures. + External agents can use this to discover all available configuration + parameters and their types. + + Returns: + A JSON Schema dict compliant with JSON Schema Draft 2020-12. + """ + return DJConfig.model_json_schema() + + +def get_defaults() -> Dict[str, Any]: + """Return a dict of {field_name: default_value} for all fields. + + Nested Pydantic sub-models are serialized recursively by + ``model_dump()``, so their defaults appear as plain nested dicts. + """ + return DJConfig().model_dump() + + +# ============================================================ +# Parser registration +# ============================================================ + + +def register_schema_to_parser(parser: ArgumentParser) -> None: + """Register all DJConfig fields onto the given ArgumentParser. + + This bridges the Pydantic model back to jsonargparse so that + ``build_base_parser()`` can remain thin. + + Special handling via ``json_schema_extra``: + - Fields with ``argparse_action`` (e.g. 'store_true') use + ``action=`` instead of ``type=``. + - Fields with ``argparse_nargs`` (e.g. '+') use ``nargs=`` + instead of ``type=``. + - Nested Pydantic sub-models are registered with ``type=`` + pointing to the sub-model class, letting jsonargparse handle + the dot-notation expansion automatically. + - For backward compatibility, nargs fields also register a + kebab-case alias.
+ """ + from pydantic_core import PydanticUndefined + + for field_name, field_info in DJConfig.model_fields.items(): + extra = field_info.json_schema_extra or {} + action = extra.get(_ARGPARSE_ACTION) + nargs = extra.get(_ARGPARSE_NARGS) + description = field_info.description or "" + + # Resolve the actual default value: Pydantic uses + # PydanticUndefined for fields with default_factory + if field_info.default is not PydanticUndefined: + default = field_info.default + elif field_info.default_factory is not None: + default = field_info.default_factory() + else: + default = None + + kwargs: Dict[str, Any] = {"help": description} + + if action is not None: + # action-based argument (e.g. --debug with store_true) + kwargs["action"] = action + elif nargs is not None: + # nargs-based argument (e.g. --custom_operator_paths) + kwargs["nargs"] = nargs + else: + # standard typed argument — use the annotation directly + kwargs["type"] = field_info.annotation + kwargs["default"] = default + + if action is None and nargs is None: + kwargs["default"] = default + + parser.add_argument(f"--{field_name}", **kwargs) + + # For backward compatibility, also register kebab-case alias + # for nargs fields + if nargs is not None and "_" in field_name: + kebab_name = field_name.replace("_", "-") + if kebab_name != field_name: + alias_kwargs = dict(kwargs) + alias_kwargs["dest"] = field_name + parser.add_argument(f"--{kebab_name}", **alias_kwargs) + + +# ============================================================ +# Post-parse normalization +# ============================================================ + + +def _is_pydantic_model_annotation(annotation) -> bool: + """Check whether *annotation* refers to a Pydantic BaseModel. + + Handles plain ``SubModel`` as well as ``Optional[SubModel]`` + (i.e. ``Union[SubModel, None]``). 
+ """ + if isinstance(annotation, type) and issubclass(annotation, BaseModel): + return True + # Optional[X] is Union[X, None] + args = getattr(annotation, "__args__", None) + if args: + return any(isinstance(a, type) and issubclass(a, BaseModel) for a in args) + return False + + +# Field names whose values are Pydantic sub-models and therefore +# arrive as jsonargparse Namespace objects after parsing. We convert +# them back to plain dicts so that downstream code can keep using +# the established dict-access pattern (e.g. cfg.dataset["configs"]). +_NESTED_MODEL_FIELDS = frozenset( + field_name + for field_name, field_info in DJConfig.model_fields.items() + if _is_pydantic_model_annotation(field_info.annotation) +) + + +def flatten_nested_namespaces(cfg) -> None: + """Convert nested Namespace values back to plain dicts **in-place**. + + After ``jsonargparse`` parses a config that contains Pydantic + sub-models, those fields become ``Namespace`` objects. The rest + of the Data-Juicer codebase expects them to be plain ``dict``s + (e.g. ``cfg.dataset["configs"]``). + + Call this once right before ``init_configs`` returns so that all + downstream consumers see dicts, not Namespaces. 
+ """ + from jsonargparse import Namespace as JAPNamespace + + def _namespace_to_dict_recursive(obj): + """Recursively convert Namespace to dict.""" + if isinstance(obj, JAPNamespace): + result = {} + for key in vars(obj): + result[key] = _namespace_to_dict_recursive(getattr(obj, key)) + return result + if isinstance(obj, list): + return [_namespace_to_dict_recursive(item) for item in obj] + return obj + + for field_name in _NESTED_MODEL_FIELDS: + value = getattr(cfg, field_name, None) + if value is not None and isinstance(value, JAPNamespace): + setattr(cfg, field_name, _namespace_to_dict_recursive(value)) diff --git a/data_juicer/core/data/config_validator.py b/data_juicer/core/data/config_validator.py index 71bcffb0f0..016f4de4f5 100644 --- a/data_juicer/core/data/config_validator.py +++ b/data_juicer/core/data/config_validator.py @@ -10,7 +10,8 @@ class ConfigValidationError(Exception): class ConfigValidator: """Mixin class for configuration validation""" - # Define validation rules for each strategy type + # Define validation rules for each strategy type. + # Subclasses override this with their own rules. CONFIG_VALIDATION_RULES = { "required_fields": [], # Fields that must be present "optional_fields": [], # Fields that are optional @@ -18,18 +19,77 @@ class ConfigValidator: "custom_validators": {}, # Custom validation functions } + def _get_merged_rules(self) -> Dict: + """Merge BASE_CONFIG_RULES (if defined) with CONFIG_VALIDATION_RULES. + + The base rules provide common fields (e.g. ``weight``, ``type``) + that are valid for all strategies. Subclass rules add + strategy-specific fields. The merge is additive: lists are + concatenated (deduplicated) and dicts are shallow-merged with + subclass values taking precedence for the same key. + + Returns: + A merged rules dict with the same structure as + ``CONFIG_VALIDATION_RULES``. 
+ """ + base_rules = getattr(self, "BASE_CONFIG_RULES", None) + sub_rules = self.CONFIG_VALIDATION_RULES + + if not base_rules: + return sub_rules + + def _merge_lists(base_list, sub_list): + """Concatenate two lists, preserving order and deduplicating.""" + seen = set(sub_list) + merged = list(sub_list) + for item in base_list: + if item not in seen: + merged.append(item) + seen.add(item) + return merged + + def _merge_dicts(base_dict, sub_dict): + """Shallow-merge two dicts; sub_dict values win on conflict.""" + merged = dict(base_dict) + merged.update(sub_dict) + return merged + + return { + "required_fields": _merge_lists( + base_rules.get("required_fields", []), + sub_rules.get("required_fields", []), + ), + "optional_fields": _merge_lists( + base_rules.get("optional_fields", []), + sub_rules.get("optional_fields", []), + ), + "field_types": _merge_dicts( + base_rules.get("field_types", {}), + sub_rules.get("field_types", {}), + ), + "custom_validators": _merge_dicts( + base_rules.get("custom_validators", {}), + sub_rules.get("custom_validators", {}), + ), + } + def validate_config(self, ds_config: Dict) -> None: """ Validate the configuration dictionary. + Merges ``BASE_CONFIG_RULES`` (common fields from the base class) + with the subclass ``CONFIG_VALIDATION_RULES`` before validation. 
+ Args: ds_config: Configuration dictionary to validate Raises: - ValidationError: If validation fails + ConfigValidationError: If validation fails """ + rules = self._get_merged_rules() + # Check required fields - missing_fields = [field for field in self.CONFIG_VALIDATION_RULES["required_fields"] if field not in ds_config] + missing_fields = [field for field in rules["required_fields"] if field not in ds_config] if missing_fields: raise ConfigValidationError(f"Missing required fields: {', '.join(missing_fields)}") @@ -37,18 +97,22 @@ def validate_config(self, ds_config: Dict) -> None: # no need for any special checks # Check field types - for field, expected_type in self.CONFIG_VALIDATION_RULES["field_types"].items(): + for field, expected_type in rules["field_types"].items(): if field in ds_config: value = ds_config[field] if not isinstance(value, expected_type): + if isinstance(expected_type, type): + type_name = expected_type.__name__ + elif isinstance(expected_type, tuple): + type_name = " | ".join(t.__name__ for t in expected_type) + else: + type_name = str(expected_type) raise ConfigValidationError( - f"Field '{field}' must be of " - f"type '{expected_type.__name__}', " - f"got '{type(value).__name__}'" + f"Field '{field}' must be of " f"type '{type_name}', " f"got '{type(value).__name__}'" ) # Run custom validators - for field, validator in self.CONFIG_VALIDATION_RULES["custom_validators"].items(): + for field, validator in rules["custom_validators"].items(): if field in ds_config: try: validator(ds_config[field]) diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py index 31e2fc46bc..5ce3d25797 100644 --- a/data_juicer/core/data/dataset_builder.py +++ b/data_juicer/core/data/dataset_builder.py @@ -65,8 +65,10 @@ def __init__(self, cfg: Namespace, executor_type: str = "default"): raise ConfigValidationError('Dataset config should have a "configs" key') if not isinstance(ds_configs["configs"], list) or 
len(ds_configs["configs"]) == 0: raise ConfigValidationError('Dataset config "configs" should be a non-empty list') - if "max_sample_num" in ds_configs and ( - not isinstance(ds_configs["max_sample_num"], int) or ds_configs["max_sample_num"] <= 0 + if ( + "max_sample_num" in ds_configs + and ds_configs["max_sample_num"] is not None + and (not isinstance(ds_configs["max_sample_num"], int) or ds_configs["max_sample_num"] <= 0) ): raise ConfigValidationError('Dataset config "max_sample_num" should be a positive integer') for ds_config in ds_configs["configs"]: @@ -95,7 +97,7 @@ def __init__(self, cfg: Namespace, executor_type: str = "default"): logger.error(f"No data load strategies found for {ds_configs}") raise ConfigValidationError("No data load strategies found") - # initialzie the sample numbers + # initialize the sample numbers self.max_sample_num = ds_configs.get("max_sample_num", None) # get weights and sample numbers if self.max_sample_num: diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index d07f7184d2..e63e5a2a62 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -52,11 +52,35 @@ class DataLoadStrategy(ABC, ConfigValidator): abstract class for data load strategy """ + # Common fields consumed by the base class for all strategies. + # Subclasses define their own CONFIG_VALIDATION_RULES which will + # be merged with these base rules during validation. + BASE_CONFIG_RULES = { + "optional_fields": ["weight", "type"], + "field_types": {"weight": (int, float), "type": str}, + "field_defaults": {"weight": 1.0}, + "field_descriptions": { + "weight": ( + "Sampling weight for dataset mixing. When multiple " + "datasets are configured, this controls the relative " + "proportion of samples drawn from each source." + ), + "type": ( + "Dataset type that determines which load strategy to " + "use. 
Common values: 'local', 'huggingface', " + "'modelscope', 'arxiv', 'commoncrawl', 's3'." + ), + }, + } + def __init__(self, ds_config: Dict, cfg: Namespace): self.validate_config(ds_config) self.ds_config = ds_config self.cfg = cfg - self.weight = ds_config.get("weight", 1.0) # default weight is 1.0 + self.weight = ds_config.get( + "weight", + self.BASE_CONFIG_RULES["field_defaults"]["weight"], + ) @abstractmethod def load_data(self, **kwargs) -> DJDataset: diff --git a/data_juicer/tools/DJ_mcp_recipe_flow.py b/data_juicer/tools/DJ_mcp_recipe_flow.py index 81ad109e26..701673faea 100644 --- a/data_juicer/tools/DJ_mcp_recipe_flow.py +++ b/data_juicer/tools/DJ_mcp_recipe_flow.py @@ -20,78 +20,24 @@ def get_global_config_schema() -> dict: """ - Get the full schema of all available global configuration options + Get the full JSON Schema of all available global configuration options for Data-Juicer. - Returns a dictionary where each key is a config parameter name, - and the value is a dict containing: - - type: the expected type of the parameter (e.g. "bool", "int", "str") - - default: the default value - - description: a human-readable description of the parameter + Returns the complete JSON Schema generated from the ``DJConfig`` + Pydantic model. This includes all top-level parameters and nested + sub-structures (e.g. ``dataset.max_sample_num``, + ``checkpoint.strategy``). External agents can use this to discover + every configurable parameter, its type, default value, and + description. Use this tool to discover what configuration options can be passed - to run_data_recipe via the extra_config parameter. This dynamically - reflects the latest Data-Juicer configuration, so it will always - be up-to-date even as new config options are added. + to run_data_recipe via the extra_config parameter. 
- :returns: A dict mapping config parameter names to their schema info + :returns: A JSON Schema dict (Draft 2020-12 compatible) """ - from data_juicer.config.config import build_base_parser + from data_juicer.config.schema import get_json_schema - parser = build_base_parser() - - if parser is None: - return {"error": "Failed to initialize config parser"} - - # Internal parameters that should not be exposed to users - excluded_params = { - "config", - "auto", - "help", - "print_config", - } - - schema = {} - for action in parser._actions: - # Skip suppressed or internal actions - if not action.option_strings: - continue - - # Use the longest option string as the parameter name - param_name = max(action.option_strings, key=len).lstrip("-") - dest = action.dest - - if dest in excluded_params or param_name in excluded_params: - continue - - # Determine type name - type_name = "str" - if action.type is not None: - if hasattr(action.type, "__name__"): - type_name = action.type.__name__ - elif hasattr(action.type, "__class__"): - type_name = str(action.type) - else: - type_name = str(action.type) - elif isinstance(action.const, bool): - type_name = "bool" - - # Handle choices - choices = None - if action.choices: - choices = list(action.choices) - - entry = { - "type": type_name, - "default": action.default, - "description": action.help or "", - } - if choices: - entry["choices"] = choices - - schema[param_name] = entry - - return schema + return get_json_schema() def get_dataset_load_strategies() -> dict: @@ -109,42 +55,48 @@ def get_dataset_load_strategies() -> dict: field that maps to a data source strategy (e.g., 'local', 'huggingface') - max_sample_num: optional max number of samples to load - Each dataset config dict should follow the required/optional fields - described in the returned strategy information. + Each strategy's returned fields are already merged with common fields + from ``DataLoadStrategy.BASE_CONFIG_RULES`` (e.g. 'weight', 'type'). 
- :returns: A dict mapping strategy identifiers to their configuration info + :returns: A dict mapping strategy identifiers to their merged + configuration info (including both common and strategy-specific + fields). """ from data_juicer.core.data.load_strategy import DataLoadStrategyRegistry strategies_info = {} - for strategy_key, strategy_class in DataLoadStrategyRegistry._strategies.items(): identifier = f"{strategy_key.executor_type}/" f"{strategy_key.data_type}/" f"{strategy_key.data_source}" - # Extract CONFIG_VALIDATION_RULES if available - validation_rules = getattr(strategy_class, "CONFIG_VALIDATION_RULES", {}) + # Use _get_merged_rules() to get combined base + subclass rules + instance = strategy_class.__new__(strategy_class) + merged_rules = instance._get_merged_rules() - # Extract class docstring description = strategy_class.__doc__ or "" description = description.strip() + # Convert field_types to string representation for serialization + field_types = merged_rules.get("field_types", {}) + field_types_serialized = { + key: ( + " | ".join(t.__name__ for t in val) + if isinstance(val, tuple) + else (val.__name__ if hasattr(val, "__name__") else str(val)) + ) + for key, val in field_types.items() + } + entry = { "executor_type": strategy_key.executor_type, "data_type": strategy_key.data_type, "data_source": strategy_key.data_source, "description": description, "class_name": strategy_class.__name__, + "required_fields": merged_rules.get("required_fields", []), + "optional_fields": merged_rules.get("optional_fields", []), + "field_types": field_types_serialized, } - if validation_rules: - entry["required_fields"] = validation_rules.get("required_fields", []) - entry["optional_fields"] = validation_rules.get("optional_fields", []) - # Convert field_types to string representation for serialization - field_types = validation_rules.get("field_types", {}) - entry["field_types"] = { - key: (val.__name__ if hasattr(val, "__name__") else str(val)) for key, val in 
field_types.items() - } - strategies_info[identifier] = entry return strategies_info diff --git a/tests/config/test_config.py b/tests/config/test_config.py index ff8036b261..05db7e2e81 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -410,7 +410,7 @@ def test_op_params_parsing(self): self.assertIn(base_param_key, params) def test_get_default_cfg(self): - """Test getting default configuration from config_min.yaml""" + """Test getting default configuration from schema definitions""" # Get default config cfg = get_default_cfg() @@ -422,8 +422,8 @@ def test_get_default_cfg(self): self.assertEqual(cfg.ray_address, 'auto') self.assertEqual(cfg.text_keys, 'text') self.assertEqual(cfg.add_suffix, False) - self.assertEqual(cfg.export_path, './outputs/') - self.assertEqual(cfg.suffixes, None) + self.assertEqual(cfg.export_path, './outputs/hello_world/hello_world.jsonl') + self.assertEqual(cfg.suffixes, []) # Test default values are of correct type self.assertIsInstance(cfg.executor_type, str) diff --git a/tests/core/data/test_dataset_builder.py b/tests/core/data/test_dataset_builder.py index 6f881602e0..a912354209 100644 --- a/tests/core/data/test_dataset_builder.py +++ b/tests/core/data/test_dataset_builder.py @@ -421,7 +421,8 @@ def test_builder_ondisk_config(self): self.assertIsInstance(cfg, Namespace) self.assertEqual(cfg.project_name, 'dataset-local-json') self.assertEqual(cfg.dataset, - {'configs': [{'path': 'sample.jsonl', 'type': 'local'}]}) + {'configs': [{'path': 'sample.jsonl', 'type': 'local'}], + 'max_sample_num': None}) self.assertEqual(not cfg.dataset_path, True) def test_builder_ondisk_config_list(self): @@ -436,7 +437,8 @@ def test_builder_ondisk_config_list(self): {'configs': [ {'path': 'sample.jsonl', 'type': 'local'}, {'path': 'sample.txt', 'type': 'local'} - ]}) + ], + 'max_sample_num': None}) self.assertEqual(not cfg.dataset_path, True) def test_builder_with_max_samples(self): @@ -538,7 +540,8 @@ def 
test_builder_ray_config(self): 'configs': [{ 'type': 'local', 'path': './test_data/sample.jsonl' - }] + }], + 'max_sample_num': None }) # Create builder and verify From 8dd7f01d4b819e1b5414a8061d743b14e04c8322 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Thu, 9 Apr 2026 15:30:52 +0800 Subject: [PATCH 2/5] test: update invalid values check for max_sample_num --- tests/core/data/test_dataset_builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/core/data/test_dataset_builder.py b/tests/core/data/test_dataset_builder.py index a912354209..1170e6437d 100644 --- a/tests/core/data/test_dataset_builder.py +++ b/tests/core/data/test_dataset_builder.py @@ -505,7 +505,8 @@ def test_mixed_dataset_configs(self): def test_invalid_max_sample_num(self): """Test handling of invalid max_sample_num""" - invalid_values = [-1, 0, "100", None] + # Note: None is a valid value (means no limit), not included here + invalid_values = [-1, 0, "100"] for value in invalid_values: self.base_cfg.dataset = { From 7e4af16e64d66cc3bc389f7f021804b21009f56c Mon Sep 17 00:00:00 2001 From: cmgzn Date: Thu, 9 Apr 2026 16:55:50 +0800 Subject: [PATCH 3/5] feat: add read options for PyArrow in config schema --- data_juicer/config/schema.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/data_juicer/config/schema.py b/data_juicer/config/schema.py index 2ccf9266fd..95930e738b 100644 --- a/data_juicer/config/schema.py +++ b/data_juicer/config/schema.py @@ -298,11 +298,20 @@ class DJConfig(BaseModel): "available options." ), ) + read_options: Dict = Field( + default_factory=dict, + description=( + "Read options passed through to PyArrow reading functions " + "(e.g., block_size for JSON reading). This configuration is " + "especially useful when reading large JSON files." + ), + ) suffixes: List[str] = Field( default_factory=list, description=( - "Suffixes of files that will be find and loaded. 
If not set, " - "we will find all suffix files under the dataset_path." + "Suffixes of files that will be found and loaded. If not set, " + "we will find all suffix files, and select a suitable formatter " + "with the most files as default." ), ) add_suffix: bool = Field( From 8eeb35348ffd22bda879f4d738ba22f72e7930a0 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Fri, 10 Apr 2026 13:58:55 +0800 Subject: [PATCH 4/5] refactor: improve operator function signature creation by separating parameters with and without defaults --- data_juicer/ops/mapper/llm_extract_mapper.py | 2 +- data_juicer/tools/DJ_mcp_granular_ops.py | 43 ++++++++++++-------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/data_juicer/ops/mapper/llm_extract_mapper.py b/data_juicer/ops/mapper/llm_extract_mapper.py index 84572d2e20..9ef9d540bf 100644 --- a/data_juicer/ops/mapper/llm_extract_mapper.py +++ b/data_juicer/ops/mapper/llm_extract_mapper.py @@ -41,8 +41,8 @@ def __init__( self, input_keys: List[str], output_schema: Dict[str, str], - api_or_hf_model: str = "gpt-4o", *, + api_or_hf_model: str = "gpt-4o", meta_output_key: Optional[str] = MetaKeys.llm_extract, knowledge_grounding_key: Optional[str] = None, knowledge_grounding_fixed: Optional[str] = None, diff --git a/data_juicer/tools/DJ_mcp_granular_ops.py b/data_juicer/tools/DJ_mcp_granular_ops.py index 1ebd897e40..b6dc69d89e 100644 --- a/data_juicer/tools/DJ_mcp_granular_ops.py +++ b/data_juicer/tools/DJ_mcp_granular_ops.py @@ -35,27 +35,36 @@ def create_operator_function(op, mcp): docstring = op["desc"] param_docstring = op["param_desc"] - # Create new function signature with dataset_path as first parameter - # Consider adding other common parameters later, such as export_psth - new_parameters = [ - inspect.Parameter("dataset_path", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str), - inspect.Parameter( - "export_path", - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=Optional[str], - default=None, - ), - 
inspect.Parameter( - "np", - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=Optional[int], - default=None, - ), - ] + [ + # Separate operator parameters into those with and without defaults + operator_params = [ process_parameter(name, param) for name, param in sig.parameters.items() if name not in ("args", "kwargs", "self") ] + params_no_default = [p for p in operator_params if p.default == inspect.Parameter.empty] + params_with_default = [p for p in operator_params if p.default != inspect.Parameter.empty] + + # Create new function signature with dataset_path first (no default) + # followed by operator params without defaults, then optional params with defaults + new_parameters = ( + [inspect.Parameter("dataset_path", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str)] + + params_no_default + + [ + inspect.Parameter( + "export_path", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=Optional[str], + default=None, + ), + inspect.Parameter( + "np", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=Optional[int], + default=None, + ), + ] + + params_with_default + ) new_signature = sig.replace(parameters=new_parameters, return_annotation=str) def func(*args, **kwargs): From 4b1d3cee7c91d0b58cfd89bda05078347fa7b5c9 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Wed, 15 Apr 2026 14:59:51 +0800 Subject: [PATCH 5/5] refactor: simplify nested model fields handling logic --- data_juicer/config/schema.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/data_juicer/config/schema.py b/data_juicer/config/schema.py index 95930e738b..40411339cf 100644 --- a/data_juicer/config/schema.py +++ b/data_juicer/config/schema.py @@ -906,27 +906,23 @@ def _is_pydantic_model_annotation(annotation) -> bool: return False -# Field names whose values are Pydantic sub-models and therefore -# arrive as jsonargparse Namespace objects after parsing. 
We convert -# them back to plain dicts so that downstream code can keep using -# the established dict-access pattern (e.g. cfg.dataset["configs"]). -_NESTED_MODEL_FIELDS = frozenset( - field_name - for field_name, field_info in DJConfig.model_fields.items() - if _is_pydantic_model_annotation(field_info.annotation) -) +# Fields that must be converted from jsonargparse Namespace to plain dict +# after parsing, because downstream code uses isinstance(x, dict) checks. +# Other nested sub-model fields (partition, checkpoint, event_logging, etc.) +# are kept as Namespace objects, which natively support both attribute access +# (cfg.partition.target_size_mb) and dict-style access (cfg.partition["mode"]). +_DICT_CONVERT_FIELDS = frozenset({"dataset"}) def flatten_nested_namespaces(cfg) -> None: - """Convert nested Namespace values back to plain dicts **in-place**. + """Convert selected nested Namespace values back to plain dicts **in-place**. After ``jsonargparse`` parses a config that contains Pydantic - sub-models, those fields become ``Namespace`` objects. The rest - of the Data-Juicer codebase expects them to be plain ``dict``s - (e.g. ``cfg.dataset["configs"]``). - - Call this once right before ``init_configs`` returns so that all - downstream consumers see dicts, not Namespaces. + sub-models, those fields become ``Namespace`` objects. Only fields + listed in ``_DICT_CONVERT_FIELDS`` are converted to plain dicts + (because downstream code relies on ``isinstance(x, dict)`` checks). + All other nested fields are left as ``Namespace`` objects, which + natively support both attribute and dict-style access. 
""" from jsonargparse import Namespace as JAPNamespace @@ -941,7 +937,7 @@ def _namespace_to_dict_recursive(obj): return [_namespace_to_dict_recursive(item) for item in obj] return obj - for field_name in _NESTED_MODEL_FIELDS: + for field_name in _DICT_CONVERT_FIELDS: value = getattr(cfg, field_name, None) if value is not None and isinstance(value, JAPNamespace): setattr(cfg, field_name, _namespace_to_dict_recursive(value))