Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions configs/step-3.5-flash-eagle3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"eagle_config": {
"eagle_aux_hidden_state_layer_ids": [
4,
20,
40
],
"use_aux_hidden_state": true
},
"hidden_size": 4096,
"num_attention_heads": 64,
"num_key_value_heads": 8,
"intermediate_size": 11264,
"hidden_act": "silu",
"max_position_embeddings": 262144,
"vocab_size": 128896,
"draft_vocab_size": 32000,
"num_hidden_layers": 1,
"bos_token_id": 0,
"eos_token_id": 1,
"pad_token_id": 2,
"rms_norm_eps": 1e-05,
"tie_word_embeddings": false,
"use_cache": true,
"model_type": "llama",
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0"
}
36 changes: 36 additions & 0 deletions examples/run_step3p5_flash_eagle3_online.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/bin/bash
# Train an EAGLE3 draft model for stepfun-ai/Step-3.5-Flash online (target model
# served via the SGLang backend) using the regenerated UltraChat dataset.
#
# Usage: run_step3p5_flash_eagle3_online.sh [NUM_GPUS] [TP_SIZE]
#   NUM_GPUS defaults to 8, TP_SIZE defaults to 4.
#   BUILD_DATASET_NUM_PROC may be overridden via the environment (default 64).
#
# Fail fast on errors, unset variables, and failures inside pipelines.
set -euo pipefail

# Resolve the repository root relative to this script's location.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")
export TORCHINDUCTOR_CACHE_DIR="$ROOT_DIR/cache/compiled_kernels"

# train eagle3 for step-3.5-flash
NUM_GPUS=${1:-8}
TP_SIZE=${2:-4}
BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64}

# train eagle3 online
torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3.py" \
    --target-model-path stepfun-ai/Step-3.5-Flash \
    --draft-model-config configs/step-3.5-flash-eagle3.json \
    --train-data-path "$ROOT_DIR/cache/dataset/ultrachat_train_regen.jsonl" \
    --build-dataset-num-proc "$BUILD_DATASET_NUM_PROC" \
    --output-dir "$ROOT_DIR/outputs/step-3.5-flash-eagle3-ultrachat-regen-online" \
    --tp-size "$TP_SIZE" \
    --sglang-ep-size "$TP_SIZE" \
    --target-model-backend sglang \
    --trust-remote-code \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 5e-5 \
    --max-length 4096 \
    --sglang-attention-backend fa3 \
    --chat-template step3.5 \
    --cache-dir "$ROOT_DIR/cache" \
    --dist-timeout 60 \
    --sglang-mem-fraction-static 0.75 \
    --report-to wandb \
    --wandb-project specforge-step3p5-flash \
    --wandb-name specforge-step3p5-flash-ultrachat-regen
29 changes: 29 additions & 0 deletions scripts/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def parse_args():
"perfectblend-llama4-scout-instruct",
"perfectblend-llama4-maverick-instruct",
"magpie-qwen2.5-pro-1m-v0.1",
"smoltalk-chinese",
"sharegpt4v",
"allava4v",
"opc",
Expand Down Expand Up @@ -167,6 +168,31 @@ def process_ultrachat_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, in
return row, 0


def process_smoltalk_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
    """Process a row from the opencsg/smoltalk-chinese dataset.

    The function expects a row with the following schema:
        {
            "id": str,
            "conversations": [
                {
                    "role": "user" | "assistant",
                    "content": str
                }
            ]
        }

    Args:
        row: A single dataset row.
        dataset_name: Unused; kept for signature parity with the other
            ``process_*_row`` functions dispatched from ``main()``.

    Returns:
        A tuple of (processed row with only ``role``/``content`` kept per
        message, 0). The second element is the skipped-message count,
        always 0 for this dataset.

    Raises:
        ValueError: If a message carries a role other than "user" or
            "assistant".
    """
    # smoltalk uses "conversations", not "messages".
    conversations = row["conversations"]
    formatted_conversations = []
    for message in conversations:
        role = message["role"]  # already "user" or "assistant" — no mapping needed
        content = message["content"]
        if role not in ("user", "assistant"):
            # Raise instead of assert: asserts are stripped under `python -O`,
            # which would let malformed rows through silently.
            raise ValueError(
                f"Unexpected role {role!r} in smoltalk-chinese row {row.get('id')!r}"
            )
        formatted_conversations.append({"role": role, "content": content})
    return {"id": row["id"], "conversations": formatted_conversations}, 0


def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
"""
sharegpt dataset schema:
Expand Down Expand Up @@ -575,6 +601,9 @@ def main():
ds = load_dataset("Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1")["train"]
ds = ds.rename_column("uuid", "id")
proc_fn = process_sharegpt_row
elif args.dataset == "smoltalk-chinese":
ds = load_dataset("zjxia/smoltalk-chinese")["train"]
proc_fn = process_smoltalk_row
elif args.dataset == "sharegpt4v":
ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"]
raise Exception("Not supported sharegpt4v now")
Expand Down
2 changes: 1 addition & 1 deletion specforge/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def from_args(args: argparse.Namespace) -> "SGLangBackendArgs":
args.target_batch_size if hasattr(args, "target_batch_size") else None
),
sglang_max_total_tokens=(
args.target_batch_size * args.max_length
int(args.target_batch_size * args.max_length * 1.2)
if hasattr(args, "target_batch_size") and hasattr(args, "max_length")
else None
),
Expand Down
12 changes: 12 additions & 0 deletions specforge/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ def get_all_template_names(self) -> List[str]:
),
)

# Chat template for Step-3.5-Flash: ChatML-style <|im_start|>/<|im_end|> turn
# markers, registered under the name passed as --chat-template step3.5.
# Uses the "thinking" parser with thinking enabled by default — presumably
# the model emits reasoning segments; confirm against the model's tokenizer
# config (NOTE(review)).
TEMPLATE_REGISTRY.register(
    name="step3.5",
    template=ChatTemplate(
        assistant_header="<|im_start|>assistant\n",
        user_header="<|im_start|>user\n",
        system_prompt="You are a helpful assistant.",
        end_of_turn_token="<|im_end|>\n",
        parser_type="thinking",
        enable_thinking=True,
    ),
)

TEMPLATE_REGISTRY.register(
name="phi3",
template=ChatTemplate(
Expand Down