Commit 2ab6da7

merge online and offline into one script (#308)
* merge online and offline into one script
* polish

Co-authored-by: root <FrankLeeeee>

1 parent 2de6996 commit 2ab6da7

28 files changed
Lines changed: 233 additions & 1813 deletions

docs/advanced_features/customization.md

Lines changed: 2 additions & 2 deletions

@@ -6,7 +6,7 @@
 torchrun \
     --standalone \
     --nproc_per_node 8 \
-    ./scripts/train_eagle3_online.py \
+    ./scripts/train_eagle3.py \
     --target-model-path meta-llama/Llama-3.1-8B-Instruct \
     --draft-model-config ./configs/llama3-8B-eagle3.json \
     --train-data-path ./cache/dataset/sharegpt.jsonl \
@@ -19,7 +19,7 @@ torchrun \
     --cache-dir ./cache
 ```
 
-If you wish to understand what each argument does, you can run `python scripts/train_eagle3_online.py --help` to see the full list of arguments. Particularly, we will discuss some important arguments below.
+If you wish to understand what each argument does, you can run `python scripts/train_eagle3.py --help` to see the full list of arguments. Particularly, we will discuss some important arguments below.
 - `--chat-template`: This should be the chat template to use for the model, so please make sure you set it to the correct value.
 - `--cache-dir`: This directory contains the dataset cache including the `input_ids`, `loss_mask`, `attention_mask` and `vocab_mapping`. These caches can make your data loading much faster once a cache is generated. The cache file has a name which is obtained by hashing the dataset path to avoid cache collision.
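The doc above says the cache file name is obtained by hashing the dataset path to avoid collisions. The exact hash function and naming scheme used by the script are not shown in this diff; the snippet below is only a minimal sketch of the idea (the SHA-256 choice, digest length, and `cache_` prefix are assumptions):

```python
import hashlib

def cache_file_name(dataset_path: str, suffix: str = ".pt") -> str:
    # Hash the dataset path so two different datasets never share a cache file,
    # while the same path always maps back to the same cache file.
    digest = hashlib.sha256(dataset_path.encode("utf-8")).hexdigest()[:16]
    return f"cache_{digest}{suffix}"

name = cache_file_name("./cache/dataset/sharegpt.jsonl")
print(name)  # e.g. cache_<16 hex chars>.pt
```

The key property is determinism: re-running with the same `--train-data-path` reuses the existing cache instead of rebuilding it.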

docs/basic_usage/data_preparation.md

Lines changed: 1 addition & 1 deletion

@@ -100,7 +100,7 @@ To use pre-formatted datasets, add the `--is-preformatted` flag to your training
 
 ```bash
 torchrun --standalone --nproc_per_node 8 \
-    scripts/train_eagle3_online.py \
+    scripts/train_eagle3.py \
     --is-preformatted \
     --chat-template qwen \
     --train-data-path ./your_preformatted_dataset.jsonl \
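The pre-formatted dataset above is supplied as a JSONL file. The record schema the trainer expects is not shown on this page; the snippet below only illustrates the JSONL mechanics (one JSON object per line), with a hypothetical `text` field standing in for whatever fields the project actually requires:

```python
import json
import os
import tempfile

# Hypothetical records; the real field names depend on the trainer's schema.
records = [{"text": "example conversation 1"}, {"text": "example conversation 2"}]

path = os.path.join(tempfile.mkdtemp(), "your_preformatted_dataset.jsonl")
with open(path, "w", encoding="utf-8") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")  # one JSON object per line

with open(path, encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f]
print(len(loaded))  # 2
```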

examples/README.md

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels \
 torchrun \
     --standalone \
     --nproc_per_node $NUM_GPUS \
-    $ROOT_DIR/scripts/train_eagle3_online.py \
+    $ROOT_DIR/scripts/train_eagle3.py \
     --target-model-path meta-llama/Llama-3.1-8B-Instruct \
     --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
     --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \

examples/prepare_hidden_states.sh

Lines changed: 0 additions & 30 deletions
This file was deleted.

examples/run_deepseek_v2_lite_eagle3_online.sh

Lines changed: 7 additions & 6 deletions

@@ -7,15 +7,16 @@ NUM_GPUS=${1:-8}
 torchrun \
     --standalone \
     --nproc_per_node $NUM_GPUS \
-    $ROOT_DIR/scripts/train_eagle3_online.py \
-    --target-model-path DeepSeek-V2-Lite \
+    $ROOT_DIR/scripts/train_eagle3.py \
+    --target-model-path deepseek-ai/DeepSeek-V2-Lite \
     --draft-model-config $ROOT_DIR/configs/deepseek-v2-lite-eagle3.json \
-    --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
-    --output-dir $ROOT_DIR/outputs/deepseek-v2-lite-eagle3 \
+    --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
+    --output-dir $ROOT_DIR/outputs/deepseek-v2-lite-eagle3-sharegpt \
     --num-epochs 10 \
     --batch-size 1 \
     --tp-size 1 \
     --learning-rate 1e-4 \
-    --max-length 2048 \
+    --max-length 4096 \
     --chat-template deepseek \
-    --cache-dir $ROOT_DIR/cache \
+    --target-model-backend hf \
+    --cache-dir $ROOT_DIR/cache
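These example scripts take the GPU count as an optional first argument via `NUM_GPUS=${1:-8}`, visible in the hunk header above. A quick sketch of how that bash default expansion behaves:

```shell
# ${1:-8} expands to the first positional argument if it is set and
# non-empty, otherwise to the default value 8.
set -- 4                 # simulate: script invoked with one argument, "4"
NUM_GPUS=${1:-8}
echo "with arg: $NUM_GPUS"      # prints 4

set --                   # simulate: script invoked with no arguments
NUM_GPUS=${1:-8}
echo "without arg: $NUM_GPUS"   # prints 8
```

So `bash run_deepseek_v2_lite_eagle3_online.sh 4` trains on 4 GPUs, while running it with no argument falls back to 8.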

examples/run_gpt_oss_120b_eagle3_online.sh

Lines changed: 2 additions & 2 deletions

@@ -7,7 +7,7 @@ NUM_GPUS=${1:-8}
 torchrun \
     --standalone \
     --nproc_per_node $NUM_GPUS \
-    $ROOT_DIR/scripts/train_eagle3_online.py \
+    $ROOT_DIR/scripts/train_eagle3.py \
     --target-model-path openai/gpt-oss-120b \
     --draft-model-config $ROOT_DIR/configs/gpt-oss-20B-eagle3.json \
     --train-data-path $ROOT_DIR/cache/dataset/perfect-blend-gptoss-20B.jsonl \
@@ -16,7 +16,7 @@ torchrun \
     --num-epochs 10 \
     --batch-size 1 \
     --learning-rate 1e-4 \
-    --max-length 2048 \
+    --max-length 4096 \
     --chat-template gpt-oss \
     --cache-dir $ROOT_DIR/cache \
     --dist-timeout 60

examples/run_gpt_oss_120b_eagle3_sgl_online.sh

Lines changed: 0 additions & 112 deletions
This file was deleted.

examples/run_gpt_oss_20b_eagle3_online.sh

Lines changed: 2 additions & 5 deletions

@@ -7,18 +7,15 @@ NUM_GPUS=${1:-8}
 torchrun \
     --standalone \
     --nproc_per_node $NUM_GPUS \
-    $ROOT_DIR/scripts/train_eagle3_online.py \
+    $ROOT_DIR/scripts/train_eagle3.py \
     --target-model-path openai/gpt-oss-20b \
     --draft-model-config $ROOT_DIR/configs/gpt-oss-20B-eagle3.json \
     --train-data-path $ROOT_DIR/cache/dataset/perfect-blend-gptoss-20B.jsonl \
     --output-dir $ROOT_DIR/outputs/perfect-blend-gptoss-20b-eagle3 \
     --num-epochs 10 \
     --batch-size 1 \
     --learning-rate 1e-4 \
-    --max-length 2048 \
+    --max-length 4096 \
     --chat-template gpt-oss \
     --cache-dir $ROOT_DIR/cache \
     --dist-timeout 60
-
-
-# --train-data-path $ROOT_DIR/cache/dataset/perfect-blend-gptoss-20B.jsonl \

examples/run_llama3.3_eagle3_online.sh

Lines changed: 2 additions & 2 deletions

@@ -7,7 +7,7 @@ NUM_GPUS=${1:-8}
 torchrun \
     --standalone \
     --nproc_per_node $NUM_GPUS \
-    $ROOT_DIR/scripts/train_eagle3_online.py \
+    $ROOT_DIR/scripts/train_eagle3.py \
     --target-model-path meta-llama/Llama-3.3-70B-Instruct \
     --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
     --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
@@ -16,6 +16,6 @@ torchrun \
     --batch-size 1 \
     --tp-size 8 \
     --learning-rate 1e-4 \
-    --max-length 2048 \
+    --max-length 4096 \
     --chat-template llama3 \
     --cache-dir $ROOT_DIR/cache
Lines changed: 28 additions & 68 deletions

@@ -1,75 +1,35 @@
-#!/bin/bash
-export PERSIST_DIR=/tmp # Please Change this to your own directory
-export MODEL_PATH="meta-llama/Llama-3.1-8B-Instruct"
-export DATASET_PATH=$PERSIST_DIR/dataset/
-export CACHE_DIR=$PERSIST_DIR/cache/
-export OUTPUT_DIR=$PERSIST_DIR/outputs/
-export HIDDEN_STATES_DIR=$PERSIST_DIR/hidden_states/
-export MAX_LENGTH=2048
-export CHAT_TEMPLATE=llama3
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+NUM_GPUS=${1:-8}
 
-hf download $MODEL_PATH
-hf download Aeala/ShareGPT_Vicuna_unfiltered --repo-type dataset
-
-python scripts/prepare_data.py --dataset sharegpt --output-path $DATASET_PATH --split-eval
-python scripts/build_eagle3_dataset.py \
-    --model-path $MODEL_PATH \
-    --data-path $DATASET_PATH \
-    --cache-dir $CACHE_DIR \
-    --chat-template $CHAT_TEMPLATE \
-    --max-length $MAX_LENGTH \
-
-CUDA_VISIBLE_DEVICES=1,2,3,4 torchrun --nproc_per_node=4 \
-    scripts/prepare_hidden_states.py \
-    --data-path $DATASET_PATH/sharegpt_test.jsonl \
-    --model-path $MODEL_PATH \
-    --cache-dir $CACHE_DIR \
-    --output-path $HIDDEN_STATES_DIR/sharegpt_test \
-    --chat-template $CHAT_TEMPLATE \
-    --max-length $MAX_LENGTH \
-    --enable-aux-hidden-states \
-    --tp-size 4 \
-    --batch-size 4 \
-    --mem-frac=0.75
-
-CUDA_VISIBLE_DEVICES=1,2,3,4 torchrun --nproc_per_node=4 \
+# generate hidden states
+torchrun \
+    --standalone \
+    --nproc_per_node $NUM_GPUS \
     scripts/prepare_hidden_states.py \
-    --data-path $DATASET_PATH/sharegpt_train.jsonl \
-    --model-path $MODEL_PATH \
-    --cache-dir $CACHE_DIR \
-    --output-path $HIDDEN_STATES_DIR/sharegpt_train \
-    --chat-template $CHAT_TEMPLATE \
-    --max-length $MAX_LENGTH \
+    --target-model-path meta-llama/Llama-3.1-8B-Instruct \
     --enable-aux-hidden-states \
-    --tp-size 4 \
-    --batch-size 4 \
-    --mem-frac=0.75
-
-# python scripts/view_data.py --data-path $HIDDEN_STATES_DIR/all_test/rows_0-5000/data_100.ckpt --tokenizer $MODEL_PATH
-# python scripts/view_data.py --data-path $HIDDEN_STATES_DIR/all_train/rows_0-5000/data_100.ckpt --tokenizer $MODEL_PATH
+    --data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
+    --output-path $ROOT_DIR/cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct \
+    --chat-template llama3 \
+    --max-length 4096 \
+    --tp-size 1 \
+    --batch-size 32
 
-export NUM_GPUS=4
-CUDA_VISIBLE_DEVICES=1,2,3,4 torchrun \
+# train eagle3 offline
+torchrun \
     --standalone \
     --nproc_per_node $NUM_GPUS \
-    scripts/train_eagle3_offline.py \
-    --target-model-path $MODEL_PATH \
-    --draft-model-config ./configs/llama3-8B-eagle3.json \
-    --train-data-path $DATASET_PATH/sharegpt_train.jsonl \
-    --train-hidden-states-path $HIDDEN_STATES_DIR/sharegpt_train/ \
-    --eval-data-path $DATASET_PATH/sharegpt_test.jsonl \
-    --eval-hidden-states-path $HIDDEN_STATES_DIR/sharegpt_test/ \
-    --output-dir $OUTPUT_DIR \
+    $ROOT_DIR/scripts/train_eagle3.py \
+    --target-model-path meta-llama/Llama-3.1-8B-Instruct \
+    --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
+    --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
+    --train-hidden-states-path $ROOT_DIR/cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct \
+    --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt-offline \
     --num-epochs 10 \
-    --draft-global-batch-size 16 \
-    --draft-micro-batch-size 1 \
-    --learning-rate 5e-5 \
-    --draft-attention-backend flex_attention \
-    --max-length $MAX_LENGTH \
-    --chat-template $CHAT_TEMPLATE \
-    --cache-dir $CACHE_DIR \
-    --dist-timeout=10 \
-    --log-steps 1 \
-    --report-to wandb \
-    --wandb-project llama3-8b-eagle3 \
-    --wandb-name offline-100k-4gpus
+    --batch-size 1 \
+    --tp-size 1 \
+    --learning-rate 1e-4 \
+    --max-length 4096 \
+    --chat-template llama3 \
+    --cache-dir $ROOT_DIR/cache
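The commit merges the former `train_eagle3_online.py` and `train_eagle3_offline.py` entry points into a single `train_eagle3.py`. Judging by the diffs above, offline runs are distinguished by passing `--train-hidden-states-path`. The sketch below illustrates that kind of single-entry-point dispatch; the argument names come from the diffs, but the dispatch logic itself is an assumption, not taken from the actual script:

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train an EAGLE3 draft model.")
    parser.add_argument("--target-model-path", required=True)
    parser.add_argument("--train-data-path", required=True)
    # When set, precomputed hidden states are read from disk (offline mode)
    # instead of being generated on the fly by the target model (online mode).
    parser.add_argument("--train-hidden-states-path", default=None)
    return parser

def training_mode(args: argparse.Namespace) -> str:
    return "offline" if args.train_hidden_states_path else "online"

args = build_parser().parse_args([
    "--target-model-path", "meta-llama/Llama-3.1-8B-Instruct",
    "--train-data-path", "sharegpt_train.jsonl",
])
print(training_mode(args))  # online
```

One script with a mode switch keeps the two training paths from drifting apart, which is presumably why the separate scripts were merged.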
