Skip to content

Commit 3df5b27

Browse files
authored
added tests for scripts (#331)
* added tests for scripts * added tests for scripts * polish * polish * polish * polish * polish * polish * polish * added tests for scripts * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish
1 parent 70f5187 commit 3df5b27

8 files changed

Lines changed: 347 additions & 2 deletions

File tree

.github/workflows/test.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@ jobs:
2626

2727
- name: Restore cache
2828
run: |
29+
if [ -d /github/home/cache ] && [ ! -z "$(ls -A /github/home/cache/)" ]; then
30+
cp -p -r /github/home/cache ./
31+
fi
32+
2933
if [ -d /github/home/sf ] && [ ! -z "$(ls -A /github/home/sf/)" ]; then
30-
cp -p -r /github/home/sf/* ./
34+
cp -p -r /github/home/sf ./
3135
fi
3236
3337
- name: Remove flashinfer # this is needed to avoid flashinfer jit compilation makes the program hang
@@ -55,3 +59,4 @@ jobs:
5559
- name: Save cache
5660
run: |
5761
cp -p -r sf /github/home/
62+
cp -p -r cache /github/home/

scripts/regenerate_train_data.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ def parse_arguments():
8383
action="store_true",
8484
help="Whether the model is a GPT-OSS model",
8585
)
86+
parser.add_argument(
87+
"--num-samples",
88+
type=int,
89+
default=None,
90+
help="The number of samples to regenerate, if not provided, all samples will be regenerated",
91+
)
8692
return parser.parse_args()
8793

8894

@@ -217,6 +223,9 @@ def main():
217223
)
218224
print("-" * 50)
219225

226+
success_samples = 0
227+
error_samples = 0
228+
220229
# Create progress bar
221230
with open(args.input_file_path, "r") as input_file, open(
222231
args.output_file_path, "w"
@@ -231,6 +240,12 @@ def main():
231240
start_server_index = 0
232241

233242
for line in input_file:
243+
if (
244+
args.num_samples is not None
245+
and success_samples + error_samples >= args.num_samples
246+
):
247+
break
248+
234249
data = json.loads(line.strip())
235250

236251
# find server address with the least waiting requests
@@ -249,10 +264,12 @@ def main():
249264
error_file_handle.write(
250265
json.dumps(regen_data, ensure_ascii=False) + "\n"
251266
)
267+
error_samples += 1
252268
else:
253269
output_file_handle.write(
254270
json.dumps(regen_data, ensure_ascii=False) + "\n"
255271
)
272+
success_samples += 1
256273
waiting_queue[server_address].remove(req_future)
257274
finished_on_request = True
258275

@@ -280,12 +297,16 @@ def main():
280297
error_file_handle.write(
281298
json.dumps(regen_data, ensure_ascii=False) + "\n"
282299
)
300+
error_samples += 1
283301
else:
284302
output_file_handle.write(
285303
json.dumps(regen_data, ensure_ascii=False) + "\n"
286304
)
305+
success_samples += 1
287306

288-
print(f"\nProcessing completed!")
307+
print(
308+
f"\nProcessing completed! {success_samples} samples regenerated, {error_samples} samples failed."
309+
)
289310

290311

291312
if __name__ == "__main__":

scripts/train_eagle3.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
9191

9292
# training hyper params
9393
parser.add_argument("--num-epochs", type=int, default=10)
94+
parser.add_argument(
95+
"--max-num-steps",
96+
type=int,
97+
default=None,
98+
help="The maximum number of steps to train. If not provided, will be calculated as num_epochs * steps_per_epoch",
99+
)
94100
parser.add_argument("--batch-size", type=int, default=1)
95101
parser.add_argument("--learning-rate", type=float, default=1e-4)
96102
parser.add_argument("--max-length", type=int, default=2048)
@@ -766,6 +772,12 @@ def main():
766772
# Save the model
767773
save_checkpoints(args, epoch, global_step, eagle3_model, optimizer)
768774

775+
if args.max_num_steps is not None and global_step >= args.max_num_steps:
776+
break
777+
778+
if args.max_num_steps is not None and global_step >= args.max_num_steps:
779+
break
780+
769781
# Close the tracker
770782
tracker.close()
771783
destroy_distributed()

tests/test_scripts/__init__.py

Whitespace-only changes.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import unittest
from pathlib import Path

# Use the shared test helper (the sibling test files in tests/test_scripts
# import it from tests.utils; importing from sglang.utils here was the odd
# one out and made the suite depend on two different helpers).
from tests.utils import execute_shell_command

# Repository-level cache directory where prepare_data.py writes its output.
CACHE_DIR = Path(__file__).parent.parent.parent.joinpath("cache")


class TestPrepareData(unittest.TestCase):
    """End-to-end test for scripts/prepare_data.py."""

    def test_prepare_sharegpt(self):
        """prepare_data.py --dataset sharegpt must (re)create the training jsonl."""
        sharegpt_train_path = CACHE_DIR.joinpath("dataset", "sharegpt_train.jsonl")

        # Start from a clean slate so the existence assertion below actually
        # proves the script produced the file in this run.
        if sharegpt_train_path.exists():
            sharegpt_train_path.unlink()

        process = execute_shell_command(
            "python scripts/prepare_data.py --dataset sharegpt"
        )
        process.wait()
        self.assertEqual(process.returncode, 0)
        self.assertTrue(sharegpt_train_path.exists())


if __name__ == "__main__":
    unittest.main(verbosity=2)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import unittest
from pathlib import Path

from tests.utils import execute_shell_command, wait_for_server

# Repository-level cache directory where datasets are written.
CACHE_DIR = Path(__file__).parent.parent.parent.joinpath("cache")


class TestRegenerateTrainData(unittest.TestCase):
    """End-to-end test for scripts/regenerate_train_data.py against a live sglang server."""

    def test_regenerate_sharegpt(self):
        # Prepare the input dataset the regeneration script reads.
        data_process = execute_shell_command(
            "python scripts/prepare_data.py --dataset sharegpt"
        )
        data_process.wait()

        # Launch the sglang server that serves the target model.
        sglang_process = execute_shell_command(
            """python3 -m sglang.launch_server \
            --model unsloth/Llama-3.2-1B-Instruct \
            --tp 1 \
            --cuda-graph-bs 4 \
            --dtype bfloat16 \
            --mem-frac=0.8 \
            --port 30000
            """,
            disable_proxy=True,
            enable_hf_mirror=True,
        )
        # Terminate the server even if an assertion below fails; otherwise a
        # failing test leaks the server process (GPU memory + port 30000) into
        # the rest of the test session.
        try:
            # No f-prefix needed: the URL has no placeholders.
            wait_for_server("http://localhost:30000", disable_proxy=True)

            regeneration_process = execute_shell_command(
                """python scripts/regenerate_train_data.py \
            --model unsloth/Llama-3.2-1B-Instruct \
            --concurrency 128 \
            --max-tokens 128 \
            --server-address localhost:30000 \
            --temperature 0.8 \
            --input-file-path ./cache/dataset/sharegpt_train.jsonl \
            --output-file-path ./cache/dataset/sharegpt_train_regen.jsonl \
            --num-samples 10
            """,
                disable_proxy=True,
                enable_hf_mirror=True,
            )
            regeneration_process.wait()
            self.assertEqual(regeneration_process.returncode, 0)
            self.assertTrue(
                CACHE_DIR.joinpath("dataset", "sharegpt_train_regen.jsonl").exists()
            )
        finally:
            sglang_process.terminate()
            sglang_process.wait()


if __name__ == "__main__":
    unittest.main(verbosity=2)
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import shutil
2+
import unittest
3+
from pathlib import Path
4+
5+
from tests.utils import execute_shell_command
6+
7+
CACHE_DIR = Path(__file__).parent.parent.parent.joinpath("cache")
8+
9+
10+
def replace_in_script(script_path: Path, pattern: str, replacement: str):
11+
with open(script_path, "r") as f:
12+
script = f.readlines()
13+
script = [line.replace(pattern, replacement) for line in script]
14+
with open(script_path, "w") as f:
15+
for line in script:
16+
f.write(line)
17+
18+
19+
class TestTrainEagle3(unittest.TestCase):
20+
21+
def setUp(self) -> None:
22+
# prepare data
23+
data_process = execute_shell_command(
24+
"python scripts/prepare_data.py --dataset sharegpt"
25+
)
26+
data_process.wait()
27+
28+
# modify the sccript to only train for 10 steps
29+
# add --max-num-steps 10 to the launch command
30+
script_path = Path(__file__).parent.parent.parent.joinpath(
31+
"examples", "run_llama3.1_8b_eagle3_online.sh"
32+
)
33+
with open(script_path, "r") as f:
34+
script = f.readlines()
35+
36+
# remove empty lines
37+
script = [line for line in script if line.strip()]
38+
script[-1] = script[-1].rstrip() + " --max-num-steps 10"
39+
40+
# replace meta-llama/Llama-3.1-8B-Instruct with unsloth/Llama-3.2-1B-Instruct
41+
# so that we don't need HF token for gated repo
42+
script = [
43+
line.replace(
44+
"meta-llama/Llama-3.1-8B-Instruct", "nreHieW/Llama-3.1-8B-Instruct"
45+
)
46+
for line in script
47+
]
48+
49+
# write the script back to the file
50+
with open(script_path, "w") as f:
51+
for line in script:
52+
f.write(line)
53+
54+
def test_online_train_eagle3_with_sglang_backend(self):
55+
# run training
56+
train_process = execute_shell_command(
57+
"bash examples/run_llama3.1_8b_eagle3_online.sh 2"
58+
)
59+
train_process.wait()
60+
self.assertEqual(train_process.returncode, 0)
61+
62+
def test_online_train_eagle3_with_hf_backend(self):
63+
# replace --target-model-backend sglang with --target-model-backend hf
64+
script_path = Path(__file__).parent.parent.parent.joinpath(
65+
"examples", "run_llama3.1_8b_eagle3_online.sh"
66+
)
67+
replace_in_script(
68+
script_path, "--target-model-backend sglang", "--target-model-backend hf"
69+
)
70+
71+
# run training
72+
train_process = execute_shell_command(
73+
"bash examples/run_llama3.1_8b_eagle3_online.sh 2"
74+
)
75+
train_process.wait()
76+
self.assertEqual(train_process.returncode, 0)
77+
78+
def test_online_train_eagle3_with_custom_backend(self):
79+
# replace --target-model-backend sglang with --target-model-backend custom
80+
script_path = Path(__file__).parent.parent.parent.joinpath(
81+
"examples", "run_llama3.1_8b_eagle3_online.sh"
82+
)
83+
replace_in_script(
84+
script_path,
85+
"--target-model-backend sglang",
86+
"--target-model-backend custom",
87+
)
88+
89+
# run training
90+
train_process = execute_shell_command(
91+
"bash examples/run_llama3.1_8b_eagle3_online.sh 2"
92+
)
93+
train_process.wait()
94+
self.assertEqual(train_process.returncode, 0)
95+
96+
def test_offline_train_eagle3(self):
97+
# remove the hidden states if they exist
98+
script_path = Path(__file__).parent.parent.parent.joinpath(
99+
"examples", "run_llama3.1_8b_eagle3_offline.sh"
100+
)
101+
replace_in_script(
102+
script_path,
103+
"meta-llama/Llama-3.1-8B-Instruct",
104+
"nreHieW/Llama-3.1-8B-Instruct",
105+
)
106+
replace_in_script(
107+
script_path,
108+
"--batch-size 32",
109+
"--batch-size 5",
110+
)
111+
replace_in_script(
112+
script_path,
113+
"scripts/prepare_hidden_states.py",
114+
"scripts/prepare_hidden_states.py --num-samples 10",
115+
)
116+
replace_in_script(
117+
script_path,
118+
"$ROOT_DIR/scripts/train_eagle3.py",
119+
"$ROOT_DIR/scripts/train_eagle3.py --max-num-steps 2",
120+
)
121+
122+
hidden_states_path = Path(__file__).parent.parent.parent.joinpath(
123+
"cache", "hidden_states", "sharegpt_train_Llama-3.1-8B-Instruct"
124+
)
125+
if hidden_states_path.exists():
126+
# delete the directory
127+
shutil.rmtree(hidden_states_path)
128+
129+
training_process = execute_shell_command(
130+
"bash examples/run_llama3.1_8b_eagle3_offline.sh 2",
131+
)
132+
training_process.wait()
133+
self.assertEqual(training_process.returncode, 0)
134+
135+
136+
if __name__ == "__main__":
137+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)