diff --git a/configs/step-3.5-flash-eagle3.json b/configs/step-3.5-flash-eagle3.json
new file mode 100644
index 00000000..3bd33d5c
--- /dev/null
+++ b/configs/step-3.5-flash-eagle3.json
@@ -0,0 +1,31 @@
+{
+  "architectures": [
+    "LlamaForCausalLMEagle3"
+  ],
+  "eagle_config": {
+    "eagle_aux_hidden_state_layer_ids": [
+      4,
+      20,
+      40
+    ],
+    "use_aux_hidden_state": true
+  },
+  "hidden_size": 4096,
+  "num_attention_heads": 64,
+  "num_key_value_heads": 8,
+  "intermediate_size": 11264,
+  "hidden_act": "silu",
+  "max_position_embeddings": 262144,
+  "vocab_size": 128896,
+  "draft_vocab_size": 32000,
+  "num_hidden_layers": 1,
+  "bos_token_id": 0,
+  "eos_token_id": 1,
+  "pad_token_id": 2,
+  "rms_norm_eps": 1e-05,
+  "tie_word_embeddings": false,
+  "use_cache": true,
+  "model_type": "llama",
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.51.0"
+}
diff --git a/examples/run_step3p5_flash_eagle3_online.sh b/examples/run_step3p5_flash_eagle3_online.sh
new file mode 100644
index 00000000..58093d2f
--- /dev/null
+++ b/examples/run_step3p5_flash_eagle3_online.sh
@@ -0,0 +1,36 @@
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
+
+# train eagle3 for step-3.5-flash
+NUM_GPUS=${1:-8}
+TP_SIZE=${2:-4}
+BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64}
+
+# train eagle3 online
+torchrun \
+    --standalone \
+    --nproc_per_node $NUM_GPUS \
+    $ROOT_DIR/scripts/train_eagle3.py \
+    --target-model-path stepfun-ai/Step-3.5-Flash \
+    --draft-model-config configs/step-3.5-flash-eagle3.json \
+    --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train_regen.jsonl \
+    --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \
+    --output-dir $ROOT_DIR/outputs/step-3.5-flash-eagle3-ultrachat-regen-online \
+    --tp-size $TP_SIZE \
+    --sglang-ep-size $TP_SIZE \
+    --target-model-backend sglang \
+    --trust-remote-code \
+    --num-epochs 10 \
+    --batch-size 1 \
+    --learning-rate 5e-5 \
+    --max-length 4096 \
+    --sglang-attention-backend fa3 \
+    --chat-template step3.5 \
+    --cache-dir $ROOT_DIR/cache \
+    --dist-timeout 60 \
+    --sglang-mem-fraction-static 0.75 \
+    --report-to wandb \
+    --wandb-project specforge-step3p5-flash \
+    --wandb-name specforge-step3p5-flash-ultrachat-regen
diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py
index 9f2c7b24..4b9fbdf0 100644
--- a/scripts/prepare_data.py
+++ b/scripts/prepare_data.py
@@ -47,6 +47,7 @@ def parse_args():
             "perfectblend-llama4-scout-instruct",
             "perfectblend-llama4-maverick-instruct",
             "magpie-qwen2.5-pro-1m-v0.1",
+            "smoltalk-chinese",
             "sharegpt4v",
             "allava4v",
             "opc",
@@ -167,6 +168,31 @@ def process_ultrachat_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, in
     return row, 0
 
 
+def process_smoltalk_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
+    """Process a row from the opencsg/smoltalk-chinese dataset.
+
+    The function expects a row with the following schema:
+    {
+        "conversations": [
+            {
+                "role": "user" | "assistant",
+                "content": str
+            }
+        ]
+    }
+    """
+    conversations = row[
+        "conversations"
+    ]  # smoltalk uses "conversations", not "messages"
+    formatted_conversations = []
+    for message in conversations:
+        role = message["role"]  # already "user" or "assistant", no mapping needed
+        content = message["content"]
+        assert role in ["user", "assistant"]
+        formatted_conversations.append({"role": role, "content": content})
+    return {"id": row["id"], "conversations": formatted_conversations}, 0
+
+
 def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
     """
     sharegpt dataset schema:
@@ -575,6 +601,9 @@ def main():
         ds = load_dataset("Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1")["train"]
         ds = ds.rename_column("uuid", "id")
         proc_fn = process_sharegpt_row
+    elif args.dataset == "smoltalk-chinese":
+        ds = load_dataset("zjxia/smoltalk-chinese")["train"]
+        proc_fn = process_smoltalk_row
     elif args.dataset == "sharegpt4v":
         ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"]
         raise Exception("Not supported sharegpt4v now")
diff --git a/specforge/args.py b/specforge/args.py
index 2cd5efc3..2f409646 100644
--- a/specforge/args.py
+++ b/specforge/args.py
@@ -194,7 +194,7 @@ def from_args(args: argparse.Namespace) -> "SGLangBackendArgs":
                 args.target_batch_size if hasattr(args, "target_batch_size") else None
             ),
             sglang_max_total_tokens=(
-                args.target_batch_size * args.max_length
+                int(args.target_batch_size * args.max_length * 1.2)
                 if hasattr(args, "target_batch_size") and hasattr(args, "max_length")
                 else None
             ),
diff --git a/specforge/data/template.py b/specforge/data/template.py
index 4dde000f..a07184dc 100644
--- a/specforge/data/template.py
+++ b/specforge/data/template.py
@@ -129,6 +129,18 @@ def get_all_template_names(self) -> List[str]:
     ),
 )
 
+TEMPLATE_REGISTRY.register(
+    name="step3.5",
+    template=ChatTemplate(
+        assistant_header="<|im_start|>assistant\n",
+        user_header="<|im_start|>user\n",
+        system_prompt="You are a helpful assistant.",
+        end_of_turn_token="<|im_end|>\n",
+        parser_type="thinking",
+        enable_thinking=True,
+    ),
+)
+
 TEMPLATE_REGISTRY.register(
     name="phi3",
     template=ChatTemplate(
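Note (not part of the patch above): a minimal sanity-check sketch for the new smoltalk-chinese processor. It assumes the repository root is on PYTHONPATH so that scripts/prepare_data.py is importable, and the hand-written sample row only mirrors the schema described in the function's docstring; real rows come from load_dataset("zjxia/smoltalk-chinese")["train"] through the existing prepare_data CLI.

```python
# Illustrative check of process_smoltalk_row; assumes the repo root is on PYTHONPATH.
from scripts.prepare_data import process_smoltalk_row

# Hand-written row mimicking the opencsg/smoltalk-chinese schema (hypothetical content).
sample_row = {
    "id": "demo-0",
    "conversations": [
        {"role": "user", "content": "用一句话介绍 EAGLE3。"},
        {"role": "assistant", "content": "EAGLE3 是一种用草稿模型加速大模型解码的投机采样方法。"},
    ],
}

processed, skipped = process_smoltalk_row(sample_row)
assert skipped == 0
# Roles already follow the "user"/"assistant" convention, so the row is passed through
# with the same id and conversation content.
assert processed == sample_row
```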