Skip to content

Commit b7e2654

Browse files
committed
feat: update MindGYM op with modified code
1 parent 7f2c203 commit b7e2654

File tree

1 file changed

+46
-58
lines changed

1 file changed

+46
-58
lines changed

data_juicer/ops/mapper/generate_challenging_qa_mapper.py

Lines changed: 46 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import re
3-
import time
3+
4+
from loguru import logger
45

56
from data_juicer.ops.base_op import OPERATORS, Mapper
67
from data_juicer.utils.lazy_loader import LazyLoader
@@ -12,30 +13,6 @@
1213
OP_NAME = "generate_challenging_qa_mapper"
1314

1415

15-
def retry_on_error(func, max_retries=5, delay=1):
16-
"""
17-
Decorator function with retry mechanism
18-
:param func: function to be retried
19-
:param max_retries: maximum number of retries
20-
:param delay: delay time before each retry (seconds)
21-
:return: function execution result
22-
"""
23-
24-
def wrapper(*args, **kwargs):
25-
retries = 0
26-
while retries < max_retries:
27-
try:
28-
return func(*args, **kwargs)
29-
except Exception as e:
30-
retries += 1
31-
print(f"Error: {e}, retry {retries}/{max_retries}...")
32-
if retries >= max_retries:
33-
raise
34-
time.sleep(delay)
35-
36-
return wrapper
37-
38-
3916
# TODO: Extend LLM-based OPs into API-based implementation.
4017
@OPERATORS.register_module(OP_NAME)
4118
class GenerateChallengingQAMapper(Mapper):
@@ -110,9 +87,8 @@ def __init__(
11087
self.user_prompt_multihop = user_prompt_multihop
11188
self.extract_prompt_qa = extract_prompt_qa
11289

113-
# tensor_parallel_size = torch.cuda.device_count()
11490
model_params = {}
115-
model_params["tensor_parallel_size"] = 4
91+
model_params["tensor_parallel_size"] = torch.cuda.device_count()
11692
self.model_key = prepare_model(model_type="vllm", pretrained_model_name_or_path=hf_model, **model_params)
11793
self.sampling_params = vllm.SamplingParams(
11894
temperature=0.9, top_p=0.95, top_k=40, repetition_penalty=1.1, max_tokens=2048
@@ -129,17 +105,16 @@ def extract_json(self, text):
129105
json_data = json.loads(json_str)
130106
return json_data
131107
except json.JSONDecodeError as e:
132-
print(f"JSON parse error: {e}")
108+
logger.warning(f"JSON parse error: {e}")
133109
return None
134110
else:
135-
print("None of valid JSON data")
111+
logger.warning("No valid JSON data found in model output.")
136112
return None
137113

138-
@retry_on_error
139114
def process_single(self, sample=None, rank=None):
140115

141116
if self.category is None:
142-
print("This OP requires processing multiple fields, and you need to specify valid `category`")
117+
raise ValueError("This OP requires processing multiple fields, and you need to specify a valid `category`")
143118

144119
model, _ = get_model(self.model_key, rank, self.use_cuda())
145120

@@ -150,34 +125,47 @@ def process_single(self, sample=None, rank=None):
150125
"content": self.user_prompt_background.format(category=self.category).replace("Qwen", self.model_name),
151126
},
152127
]
153-
background = model.chat(messages, self.sampling_params)
154128

155-
messages.extend(
156-
[
157-
{"role": "system", "content": background[0].outputs[0].text},
158-
{"role": "user", "content": self.user_prompt_subquestion.replace("Qwen", self.model_name)},
159-
]
160-
)
161-
sub_questions = model.chat(messages, self.sampling_params)
162-
163-
messages.extend(
164-
[
165-
{"role": "system", "content": sub_questions[0].outputs[0].text},
166-
{"role": "user", "content": self.user_prompt_multihop.replace("Qwen", self.model_name)},
167-
]
168-
)
169-
multihop = model.chat(messages, self.sampling_params)
170-
171-
messages.extend(
172-
[
173-
{"role": "system", "content": multihop[0].outputs[0].text},
174-
{"role": "user", "content": self.extract_prompt_qa.replace("Qwen", self.model_name)},
175-
]
176-
)
177-
qa = model.chat(messages, self.sampling_params)
178-
179-
qa = self.extract_json(qa[0].outputs[0].text)
180-
qa["thinking"] = multihop[0].outputs[0].text
129+
max_retries = 5
130+
for attempt in range(max_retries + 1): # includes the initial attempt
131+
try:
132+
background = model.chat(messages, self.sampling_params)
133+
134+
messages.extend(
135+
[
136+
{"role": "system", "content": background[0].outputs[0].text},
137+
{"role": "user", "content": self.user_prompt_subquestion.replace("Qwen", self.model_name)},
138+
]
139+
)
140+
sub_questions = model.chat(messages, self.sampling_params)
141+
142+
messages.extend(
143+
[
144+
{"role": "system", "content": sub_questions[0].outputs[0].text},
145+
{"role": "user", "content": self.user_prompt_multihop.replace("Qwen", self.model_name)},
146+
]
147+
)
148+
multihop = model.chat(messages, self.sampling_params)
149+
150+
messages.extend(
151+
[
152+
{"role": "system", "content": multihop[0].outputs[0].text},
153+
{"role": "user", "content": self.extract_prompt_qa.replace("Qwen", self.model_name)},
154+
]
155+
)
156+
qa = model.chat(messages, self.sampling_params)
157+
158+
qa = self.extract_json(qa[0].outputs[0].text)
159+
if qa is None:
160+
raise ValueError("Failed to extract valid JSON from model output.")
161+
qa["thinking"] = multihop[0].outputs[0].text
162+
except Exception as e:
163+
if attempt < max_retries: # retries are still available
164+
logger.warning(f"Attempt {attempt + 1} failed with error: {e}. Retrying...")
165+
continue
166+
else: # retry budget is exhausted
167+
logger.warning(f"All {max_retries + 1} attempts failed.")
168+
raise # re-raise the exception from the final attempt
181169

182170
sample.clear()
183171
sample.update(qa)

0 commit comments

Comments
 (0)