#!/bin/bash
# Offline EAGLE3 training pipeline for meta-llama/Llama-3.1-8B-Instruct:
# first dump target-model hidden states, then train the draft model on them.
#
# Usage: train_eagle3_offline.sh [NUM_GPUS]
#   NUM_GPUS  number of GPUs for torchrun (default: 8)
set -euo pipefail

# Absolute directory containing this script (robust to being sourced or
# invoked from any working directory).
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
# Repository root: the parent of the scripts directory.
ROOT_DIR=$(dirname -- "$SCRIPT_DIR")
# GPU count from the first CLI argument, defaulting to 8.
NUM_GPUS=${1:-8}
# Step 1: run the target model over the ShareGPT training split and dump the
# hidden states (including auxiliary layers needed by EAGLE3) to disk.
# NOTE(review): script path anchored at $ROOT_DIR (the sibling train command
# already does this) so the script works regardless of the caller's cwd.
torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/prepare_hidden_states.py" \
    --target-model-path meta-llama/Llama-3.1-8B-Instruct \
    --enable-aux-hidden-states \
    --data-path "$ROOT_DIR/cache/dataset/sharegpt_train.jsonl" \
    --output-path "$ROOT_DIR/cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct" \
    --chat-template llama3 \
    --max-length 4096 \
    --tp-size 1 \
    --batch-size 32
# Step 2: train the EAGLE3 draft model offline against the pre-computed
# hidden states produced by the prepare_hidden_states step above.
# The hidden-states path must match that step's --output-path.
torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3.py" \
    --target-model-path meta-llama/Llama-3.1-8B-Instruct \
    --draft-model-config "$ROOT_DIR/configs/llama3-8B-eagle3.json" \
    --train-data-path "$ROOT_DIR/cache/dataset/sharegpt_train.jsonl" \
    --train-hidden-states-path "$ROOT_DIR/cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct" \
    --output-dir "$ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt-offline" \
    --num-epochs 10 \
    --batch-size 1 \
    --tp-size 1 \
    --learning-rate 1e-4 \
    --max-length 4096 \
    --chat-template llama3 \
    --cache-dir "$ROOT_DIR/cache"
0 commit comments