
Commit a2cb6d8

Authored by Ximingwang-09, 纬杭, and gemini-code-assist[bot]

[feat] Regenerate train data from target model. (#223)
* regenerate train data from target model
* Update docs/basic_usage/data_preparation.md
* Update scripts/regenerate_train_data.py
* Update scripts/regenerate_train_data.py
* Update scripts/regenerate_train_data.py
* lint
* utf-8

Co-authored-by: 纬杭 <ximing.wxm@antgroup.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 07157bd commit a2cb6d8

3 files changed

Lines changed: 390 additions & 0 deletions

File tree

docs/basic_usage/data_preparation.md

Lines changed: 21 additions & 0 deletions
@@ -2,6 +2,27 @@

In this section, we will introduce how to prepare the dataset for both online and offline training. As mentioned in the [Overview](#-overview) section, online training only requires the raw dataset, while offline training requires the hidden states generated from the target model. In the sections below, we will introduce how to prepare both the raw dataset and the hidden states.

### 🔄 Regenerate Train Dataset

Many public datasets were not generated by your target model, which may lead to misalignment between the draft model's outputs and the target model's behavior, reducing the acceptance rate and inference efficiency. To address this, we **recommend regenerating the dataset with the target model**, which better aligns the draft model with the target model's output distribution, improving acceptance length and overall performance.

Run the following command to regenerate your dataset:

```bash
python3 \
    scripts/regenerate_train_data.py \
    --model <target-model-path> \
    --input-file-path <jsonl-file-path> \
    --output-file-path <regenerated-jsonl-file-path> \
    --batch-size 128 \
    --tp-size 8 \
    --num-samples 1000 \
    --port 30000 \
    --temperature 0 \
    --mem-fraction-static 0.85 \
    --auto-launch-server
```

### ☁️ Prepare Online Training Dataset

We have provided a script to prepare some sample datasets, including UltraChat (200k) and ShareGPT (120k), for demo purposes. You can easily process a dataset by running the following command. The `jsonl` files will be placed in the `cache/dataset/<dataset_name>` directory of the project path by default. These datasets will be processed into `jsonl` files, which are the raw dataset ready for online training!
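For reference, the regeneration script expects ShareGPT-style JSONL records with a `conversations` list of `role`/`content` turns. A minimal sketch of the per-record transformation it performs (drop the original trailing assistant turn, then append the regenerated one); the record contents here are illustrative:

```python
import json


def strip_trailing_assistant(record: dict) -> dict:
    """Drop the final assistant turn so the target model can regenerate it."""
    messages = record["conversations"]
    if messages and messages[-1]["role"] == "assistant":
        # Shallow-copy so the original record is left untouched
        record = {**record, "conversations": messages[:-1]}
    return record


record = {
    "conversations": [
        {"role": "user", "content": "What is speculative decoding?"},
        {"role": "assistant", "content": "original answer"},
    ]
}
stripped = strip_trailing_assistant(record)

# In the real script this content comes from the target model's completion
regenerated = {"role": "assistant", "content": "regenerated answer"}
stripped["conversations"].append(regenerated)
print(json.dumps(stripped, ensure_ascii=False))
```

The regenerated record keeps every field of the original except the final assistant turn, which is replaced by the target model's own output.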

examples/run_regenerate_data.sh

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@

```bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")

# regenerate eagle3 train data
NUM_GPUS=${1:-8}

python3 \
    "$ROOT_DIR/scripts/regenerate_train_data.py" \
    --model Qwen/QwQ-32B \
    --input-file-path "$ROOT_DIR/cache/dataset/sharegpt.jsonl" \
    --output-file-path "$ROOT_DIR/cache/dataset/sharegpt_regenerate.jsonl" \
    --batch-size 128 \
    --tp-size "$NUM_GPUS" \
    --num-samples 1000 \
    --port 30000 \
    --temperature 0 \
    --mem-fraction-static 0.85 \
    --auto-launch-server
```
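After the script finishes, a quick sanity check of the regenerated file can be sketched as follows (an illustration, not part of the commit; it assumes the ShareGPT-style `conversations` format and uses an in-memory sample in place of the real output file):

```python
import io
import json


def check_regenerated(lines) -> int:
    """Verify every JSONL record ends with a non-empty assistant turn."""
    count = 0
    for line in lines:
        record = json.loads(line)
        last = record["conversations"][-1]
        assert last["role"] == "assistant" and last["content"], record
        count += 1
    return count


# Stand-in for open("<regenerated-jsonl-file-path>", encoding="utf-8")
sample = io.StringIO(
    json.dumps(
        {
            "conversations": [
                {"role": "user", "content": "hi"},
                {"role": "assistant", "content": "hello"},
            ]
        }
    )
    + "\n"
)
print(check_regenerated(sample))  # → 1
```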

scripts/regenerate_train_data.py

Lines changed: 350 additions & 0 deletions
@@ -0,0 +1,350 @@

```python
"""
This script re-generates the dataset from the target model,
which better aligns the draft model with the target model's output distribution.
"""

import argparse
import json
import signal
import socket
import subprocess
import sys
import time
from typing import List

import requests
from tqdm import tqdm
from transformers import AutoTokenizer

# Global variables are initialized in main()
MODEL = None
MAX_TOKENS = None
BATCH_SIZE = None
TEMPERATURE = None
BASE_URL = None
HEADERS = {"Content-Type": "application/json"}
SERVER_PROCESS = None


def parse_arguments():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Re-generate training data using an sglang model server"
    )
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=4096,
        help="Maximum number of tokens (default: 4096)",
    )
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--port", type=int, default=30000)
    parser.add_argument("--input-file-path", type=str, required=True)
    parser.add_argument("--output-file-path", type=str, required=True)
    parser.add_argument("--tp-size", type=int, default=8)
    parser.add_argument("--dp-size", type=int, default=1)
    parser.add_argument("--mem-fraction-static", type=float, default=0.85)
    parser.add_argument("--max-running-requests", type=int, default=128)
    parser.add_argument(
        "--auto-launch-server",
        action="store_true",
        help="Automatically launch sglang server if port is available",
    )
    parser.add_argument("--num-samples", type=int, default=None)

    return parser.parse_args()


def is_port_in_use(port: int) -> bool:
    """Check whether a port is in use"""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("localhost", port))
            return False
        except OSError:
            return True


def launch_sglang_server(
    model_path: str,
    port: int,
    tp_size: int,
    dp_size: int,
    mem_fraction_static: float,
    max_running_requests: int,
) -> subprocess.Popen:
    """Launch the sglang server"""
    cmd = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model",
        model_path,
        "--trust-remote-code",
        "--tp-size",
        str(tp_size),
        "--dp-size",
        str(dp_size),
        "--enable-cache-report",
        "--dtype",
        "bfloat16",
        "--log-level",
        "info",
        "--mem-fraction-static",
        str(mem_fraction_static),
        "--port",
        str(port),
        "--max-running-requests",
        str(max_running_requests),
    ]

    print("Launching sglang server with command:")
    print(" ".join(cmd))

    # Start the server process
    process = subprocess.Popen(cmd)
    return process


def wait_for_server_ready(port: int, timeout: int = 3600) -> bool:
    """Wait for the server to become ready"""
    print(f"Waiting for server to be ready at localhost:{port}...")
    start_time = time.time()

    while time.time() - start_time < timeout:
        if is_port_in_use(port):
            # Port is in use, try to make a simple request
            try:
                response = requests.get(f"http://localhost:{port}/health", timeout=5)
                if response.status_code == 200:
                    print("Server is ready!")
                    return True
            except requests.exceptions.RequestException:
                pass
        time.sleep(5)

    print(f"Server failed to start within {timeout} seconds")
    return False


def cleanup_server():
    """Clean up the server process"""
    global SERVER_PROCESS
    if SERVER_PROCESS and SERVER_PROCESS.poll() is None:
        print("Shutting down sglang server...")
        SERVER_PROCESS.terminate()
        try:
            SERVER_PROCESS.wait(timeout=30)
        except subprocess.TimeoutExpired:
            SERVER_PROCESS.kill()
        print("Server shutdown complete")


def signal_handler(sig, frame):
    """Handle interrupt signals"""
    print("\nReceived interrupt signal, cleaning up...")
    cleanup_server()
    sys.exit(0)


def call_sglang_batch(prompts: List[str]) -> List[str]:
    """Send a batch of prompts to sglang /v1/completions."""
    global MODEL, MAX_TOKENS, TEMPERATURE, BASE_URL, HEADERS

    payload = {
        "model": MODEL,
        "prompt": prompts,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "skip_special_tokens": False,
    }

    resp = requests.post(BASE_URL, headers=HEADERS, json=payload, timeout=600)
    resp.raise_for_status()
    data = resp.json()
    return [choice["text"].strip() for choice in data["choices"]]


def main():
    global MODEL, MAX_TOKENS, BATCH_SIZE, TEMPERATURE, BASE_URL, SERVER_PROCESS

    # Parse command line arguments
    args = parse_arguments()

    # Set global variables
    MODEL = args.model
    MAX_TOKENS = args.max_tokens
    BATCH_SIZE = args.batch_size
    TEMPERATURE = args.temperature
    BASE_URL = f"http://localhost:{args.port}/v1/completions"
    input_file_path = args.input_file_path
    output_file_path = args.output_file_path

    # Validate parameters
    if not (0.0 <= TEMPERATURE <= 1.0):
        raise ValueError("Temperature must be between 0.0 and 1.0")

    if MAX_TOKENS <= 0:
        raise ValueError("Max tokens must be greater than 0")

    if BATCH_SIZE <= 0:
        raise ValueError("Batch size must be greater than 0")

    # Check if the server needs to be launched
    if args.auto_launch_server:
        port = args.port
        if not is_port_in_use(port):
            print(f"Port {port} is available, launching sglang server...")
            try:
                SERVER_PROCESS = launch_sglang_server(
                    model_path=args.model,
                    port=port,
                    tp_size=args.tp_size,
                    dp_size=args.dp_size,
                    mem_fraction_static=args.mem_fraction_static,
                    max_running_requests=args.max_running_requests,
                )

                # Wait for the server to be ready
                if not wait_for_server_ready(port):
                    cleanup_server()
                    raise RuntimeError("Failed to start server")

                print("Server launched successfully!")
            except Exception as e:
                print(f"Failed to launch server: {e}")
                sys.exit(1)
        else:
            print(f"Port {port} is already in use, assuming server is running")
    else:
        port = args.port
        if not is_port_in_use(port):
            print(
                f"Warning: Port {port} is not in use. Please ensure sglang server is running."
            )

    # Set up signal handlers for clean shutdown
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    print("Configuration:")
    print(f"  Model path: {MODEL}")
    print(f"  Max tokens: {MAX_TOKENS}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Temperature: {TEMPERATURE}")
    print(f"  API URL: {BASE_URL}")
    print(f"  Input file: {input_file_path}")
    print(f"  Output file: {output_file_path}")
    print("-" * 50)

    tokenizer = AutoTokenizer.from_pretrained(MODEL)

    # Variables for batch processing
    batch_prompts = []
    batch_data = []

    # Count total lines for the progress bar
    print("Counting total lines in file...")
    with open(input_file_path, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)
    total_lines = (
        min(args.num_samples, total_lines) if args.num_samples else total_lines
    )
    print(f"Total {total_lines} lines to process")

    # Create progress bar
    pbar = tqdm(total=total_lines, desc="Processing", unit="item")

    processed_count = 0

    try:
        with open(input_file_path, "r", encoding="utf-8") as input_file, open(
            output_file_path, "w", encoding="utf-8"
        ) as output_file_handle:

            for _, line in zip(range(total_lines), input_file):
                data = json.loads(line)
                messages = data["conversations"]

                # Remove the original last assistant message
                if messages[-1]["role"] == "assistant":
                    messages = messages[:-1]
                prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )

                # Add to batch
                batch_prompts.append(prompt)
                batch_data.append(data)

                # Process when the batch reaches the specified size
                if len(batch_prompts) == BATCH_SIZE:
                    # Generate outputs
                    outputs = call_sglang_batch(batch_prompts)

                    # Process each output
                    for i, output in enumerate(outputs):
                        # Create assistant message
                        assistant_message = {"role": "assistant", "content": output}

                        # Add assistant message to the original conversations
                        batch_data[i]["conversations"].append(assistant_message)

                        # Write to the output file
                        output_file_handle.write(
                            json.dumps(batch_data[i], ensure_ascii=False) + "\n"
                        )

                        processed_count += 1
                        pbar.update(1)

                    # Update progress bar description
                    pbar.set_postfix(
                        {
                            "Processed": processed_count,
                            "Current batch": len(batch_prompts),
                        }
                    )

                    # Clear the batch
                    batch_prompts = []
                    batch_data = []

            # Process remaining data that doesn't fill a complete batch
            if batch_prompts:
                outputs = call_sglang_batch(batch_prompts)

                # Process each output
                for i, output in enumerate(outputs):
                    assistant_message = {"role": "assistant", "content": output}

                    batch_data[i]["conversations"].append(assistant_message)
                    output_file_handle.write(
                        json.dumps(batch_data[i], ensure_ascii=False) + "\n"
                    )

                    # Update processing count and progress bar
                    processed_count += 1
                    pbar.update(1)

                # Update progress bar description
                pbar.set_postfix(
                    {"Processed": processed_count, "Last batch": len(batch_prompts)}
                )

        # Close the progress bar
        pbar.close()
        print(f"\nProcessing completed! Total {processed_count} lines processed")

    except Exception as e:
        print(f"Error during processing: {e}")
        raise
    finally:
        # Clean up the server if we launched it
        cleanup_server()


if __name__ == "__main__":
    main()
```
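The batch loop in the script above accumulates prompts until `BATCH_SIZE` is reached and then flushes a possibly smaller tail batch. That logic can be sketched as a standalone chunking helper (an illustration, not part of the committed script):

```python
from typing import Iterable, Iterator, List


def batched(items: Iterable[str], batch_size: int) -> Iterator[List[str]]:
    """Yield successive batches of items; the final batch may be smaller."""
    batch: List[str] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # flush the remainder, mirroring the script's tail handling
        yield batch


prompts = [f"prompt-{i}" for i in range(10)]
print([len(b) for b in batched(prompts, 4)])  # → [4, 4, 2]
```

Each yielded batch would be handed to `call_sglang_batch`, which keeps GPU utilization high while bounding the number of in-flight requests.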
