Commit 020a856

Add --is-preformatted flag to prepare_hidden_states.py (#350)
* Add --is-preformatted flag to prepare_hidden_states.py

  Added support for preformatted input data in prepare_hidden_states.py, matching the existing flag in train_eagle3.py. This allows users to skip chat template application when their data already has the template applied.

  Changes:
  - Added --is-preformatted argument to the data group
  - Updated the cache key to include is_preformatted for proper caching
  - Pass is_preformatted to build_eagle3_dataset()

* Update documentation for the --is-preformatted flag in prepare_hidden_states.py

  - Updated the script docstring with a usage example for --is-preformatted
  - Updated data_preparation.md to document --is-preformatted for offline training

* Address code review: add --output-path to the docstring example

  Added the --output-path argument back to the first usage example in the docstring, for clarity and consistency with the pre-formatted data example.
1 parent 381476b

2 files changed

Lines changed: 33 additions & 1 deletion

docs/basic_usage/data_preparation.md

Lines changed: 15 additions & 0 deletions
@@ -94,13 +94,28 @@ This format is useful when you have pre-formatted prompts that were used during
 To use pre-formatted datasets, add the `--is-preformatted` flag to your training command. Note that the `--chat-template` parameter is still needed and should match the template used in your pre-formatted text, as it is used to identify user/assistant tokens to determine the assistant spans and generate the corresponding loss mask.
 
 ```bash
+# Online training with pre-formatted data
 torchrun --standalone --nproc_per_node 8 \
     scripts/train_eagle3.py \
     --is-preformatted \
     --train-data-path ./your_preformatted_dataset.jsonl \
     # ... other arguments
 ```
+
+For offline training, you can also use `--is-preformatted` when generating hidden states:
+
+```bash
+# Generate hidden states from pre-formatted data
+torchrun --nproc_per_node=8 \
+    scripts/prepare_hidden_states.py \
+    --target-model-path meta-llama/Llama-3.1-8B-Instruct \
+    --data-path ./your_preformatted_dataset.jsonl \
+    --output-path ./cache/hidden_states \
+    --chat-template llama3 \
+    --is-preformatted \
+    --max-length 2048
+```
 
 Once you have the `jsonl` file ready, you can proceed with online training or generate hidden states for offline training. See the Training guide for more details.
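The documentation above notes that `--chat-template` is still required for pre-formatted data because the template's markers are used to locate assistant spans for the loss mask. A minimal sketch of that idea, assuming llama3-style marker strings (this is an illustration, not SpecForge's actual implementation):

```python
# Sketch: derive a character-level loss mask from preformatted text by
# locating assistant spans between template markers. The marker strings
# are llama3-style assumptions for illustration only.

ASSISTANT_START = "<|start_header_id|>assistant<|end_header_id|>\n\n"
SPAN_END = "<|eot_id|>"

def assistant_loss_mask(text: str) -> list[int]:
    """Return a 0/1 mask over characters: 1 inside assistant responses."""
    mask = [0] * len(text)
    pos = 0
    while True:
        start = text.find(ASSISTANT_START, pos)
        if start == -1:
            break
        body = start + len(ASSISTANT_START)
        end = text.find(SPAN_END, body)
        if end == -1:
            end = len(text)
        for i in range(body, end):
            mask[i] = 1
        pos = end
    return mask

text = (
    "<|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\nHello!<|eot_id|>"
)
mask = assistant_loss_mask(text)
# Only the characters of the assistant reply "Hello!" are unmasked.
assert "".join(c for c, m in zip(text, mask) if m) == "Hello!"
```

A real implementation would produce a token-level mask after tokenization, but the span-location step works the same way.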

scripts/prepare_hidden_states.py

Lines changed: 18 additions & 1 deletion
@@ -19,6 +19,17 @@
     --batch-size 32 \
     --num-samples 1000 \
     --output-path ./cache/hidden_states
+
+For pre-formatted data (with chat template already applied), add --is-preformatted:
+    torchrun --nproc_per_node=8 \
+        scripts/prepare_hidden_states.py \
+        --target-model-path meta-llama/Llama-3.1-8B-Instruct \
+        --enable-aux-hidden-states \
+        --data-path ./cache/dataset/preformatted_data.jsonl \
+        --output-path ./cache/hidden_states \
+        --chat-template llama3 \
+        --is-preformatted \
+        --max-length 2048
 """
 
 import argparse
@@ -73,6 +84,11 @@ def parse_args():
     data_group.add_argument("--data-path", type=str, required=True)
     data_group.add_argument("--max-length", type=int, default=2048)
     data_group.add_argument("--chat-template", type=str, default="llama3")
+    data_group.add_argument(
+        "--is-preformatted",
+        action="store_true",
+        help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.",
+    )
     data_group.add_argument("--num-samples", type=int, default=None)
     data_group.add_argument("--build-dataset-num-proc", type=int, default=8)
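The flag added above uses argparse's `action="store_true"`, so `args.is_preformatted` defaults to `False` and flips to `True` only when the flag appears on the command line. A minimal standalone sketch (argument set reduced for illustration):

```python
import argparse

# Sketch of the flag definition from the diff, in isolation.
parser = argparse.ArgumentParser()
data_group = parser.add_argument_group("data")
data_group.add_argument("--data-path", type=str, required=True)
data_group.add_argument(
    "--is-preformatted",
    action="store_true",
    help="Whether the input data is preformatted text with the chat "
    "template already applied to the conversation messages.",
)

default = parser.parse_args(["--data-path", "d.jsonl"])
flagged = parser.parse_args(["--data-path", "d.jsonl", "--is-preformatted"])
assert default.is_preformatted is False  # flag absent -> False
assert flagged.is_preformatted is True   # flag present -> True
```

Note that argparse converts the dashed option name to the attribute `is_preformatted`, which is why the script can read `args.is_preformatted` directly.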

@@ -558,7 +574,7 @@ def main():
     tokenizer = AutoTokenizer.from_pretrained(
         args.target_model_path, trust_remote_code=True
     )
-    cache_params_string = f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.target_model_path}-{args.num_samples}"
+    cache_params_string = f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.target_model_path}-{args.num_samples}-{args.is_preformatted}"
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
 
     # Preprocess on complete, un-sharded dataset
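The point of extending the cache key above is that two runs differing only in `--is-preformatted` must not reuse each other's processed dataset. A small sketch of that behavior (function name and parameters are illustrative, not the script's actual API):

```python
import hashlib

# Sketch of the cache-key change: once is_preformatted is part of the
# key string, runs that differ only in that flag hash to distinct keys
# and therefore hit distinct cache entries.
def cache_key(data_path, max_length, chat_template, model, num_samples, is_preformatted):
    s = f"{data_path}-{max_length}-{chat_template}-{model}-{num_samples}-{is_preformatted}"
    return hashlib.md5(s.encode()).hexdigest()

k_raw = cache_key("d.jsonl", 2048, "llama3", "m", None, False)
k_pre = cache_key("d.jsonl", 2048, "llama3", "m", None, True)
assert k_raw != k_pre  # separate cache entries per flag value
```

Without this change, enabling the flag after a prior run would silently load the cached dataset built with the template applied twice (or not at all).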
@@ -572,6 +588,7 @@ def main():
         cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
         cache_key=cache_key,
         is_vlm=args.is_vlm,
+        is_preformatted=args.is_preformatted,
         processor=processor,
         num_proc=args.build_dataset_num_proc,
     )
