11import json
22import re
3- import time
3+
4+ from loguru import logger
45
56from data_juicer .ops .base_op import OPERATORS , Mapper
67from data_juicer .utils .lazy_loader import LazyLoader
1213OP_NAME = "generate_challenging_qa_mapper"
1314
1415
15- def retry_on_error (func , max_retries = 5 , delay = 1 ):
16- """
17- Decorator function with retry mechanism
18- :param func: function to be retried
19- :param max_retries: maximum number of retries
20- :param delay: delay time before each retry (seconds)
21- :return: function execution result
22- """
23-
24- def wrapper (* args , ** kwargs ):
25- retries = 0
26- while retries < max_retries :
27- try :
28- return func (* args , ** kwargs )
29- except Exception as e :
30- retries += 1
31- print (f"Error: { e } , retry { retries } /{ max_retries } ..." )
32- if retries >= max_retries :
33- raise
34- time .sleep (delay )
35-
36- return wrapper
37-
38-
3916# TODO: Extend LLM-based OPs into API-based implementation.
4017@OPERATORS .register_module (OP_NAME )
4118class GenerateChallengingQAMapper (Mapper ):
@@ -110,9 +87,8 @@ def __init__(
11087 self .user_prompt_multihop = user_prompt_multihop
11188 self .extract_prompt_qa = extract_prompt_qa
11289
113- # tensor_parallel_size = torch.cuda.device_count()
11490 model_params = {}
115- model_params ["tensor_parallel_size" ] = 4
91+ model_params ["tensor_parallel_size" ] = torch . cuda . device_count ()
11692 self .model_key = prepare_model (model_type = "vllm" , pretrained_model_name_or_path = hf_model , ** model_params )
11793 self .sampling_params = vllm .SamplingParams (
11894 temperature = 0.9 , top_p = 0.95 , top_k = 40 , repetition_penalty = 1.1 , max_tokens = 2048
@@ -129,17 +105,16 @@ def extract_json(self, text):
129105 json_data = json .loads (json_str )
130106 return json_data
131107 except json .JSONDecodeError as e :
132- print (f"JSON parse error: { e } " )
108+ logger . warning (f"JSON parse error: { e } " )
133109 return None
134110 else :
135- print ( "None of valid JSON data" )
111+ logger . warning ( "No valid JSON data found in model output. " )
136112 return None
137113
138- @retry_on_error
139114 def process_single (self , sample = None , rank = None ):
140115
141116 if self .category is None :
142- print ("This OP requires processing multiple fields, and you need to specify valid `category`" )
117+ raise ValueError ("This OP requires processing multiple fields, and you need to specify a valid `category`" )
143118
144119 model , _ = get_model (self .model_key , rank , self .use_cuda ())
145120
@@ -150,34 +125,47 @@ def process_single(self, sample=None, rank=None):
150125 "content" : self .user_prompt_background .format (category = self .category ).replace ("Qwen" , self .model_name ),
151126 },
152127 ]
153- background = model .chat (messages , self .sampling_params )
154128
155- messages .extend (
156- [
157- {"role" : "system" , "content" : background [0 ].outputs [0 ].text },
158- {"role" : "user" , "content" : self .user_prompt_subquestion .replace ("Qwen" , self .model_name )},
159- ]
160- )
161- sub_questions = model .chat (messages , self .sampling_params )
162-
163- messages .extend (
164- [
165- {"role" : "system" , "content" : sub_questions [0 ].outputs [0 ].text },
166- {"role" : "user" , "content" : self .user_prompt_multihop .replace ("Qwen" , self .model_name )},
167- ]
168- )
169- multihop = model .chat (messages , self .sampling_params )
170-
171- messages .extend (
172- [
173- {"role" : "system" , "content" : multihop [0 ].outputs [0 ].text },
174- {"role" : "user" , "content" : self .extract_prompt_qa .replace ("Qwen" , self .model_name )},
175- ]
176- )
177- qa = model .chat (messages , self .sampling_params )
178-
179- qa = self .extract_json (qa [0 ].outputs [0 ].text )
180- qa ["thinking" ] = multihop [0 ].outputs [0 ].text
129+ max_retries = 5
130+ for attempt in range (max_retries + 1 ): # 包括首次尝试
131+ try :
132+ background = model .chat (messages , self .sampling_params )
133+
134+ messages .extend (
135+ [
136+ {"role" : "system" , "content" : background [0 ].outputs [0 ].text },
137+ {"role" : "user" , "content" : self .user_prompt_subquestion .replace ("Qwen" , self .model_name )},
138+ ]
139+ )
140+ sub_questions = model .chat (messages , self .sampling_params )
141+
142+ messages .extend (
143+ [
144+ {"role" : "system" , "content" : sub_questions [0 ].outputs [0 ].text },
145+ {"role" : "user" , "content" : self .user_prompt_multihop .replace ("Qwen" , self .model_name )},
146+ ]
147+ )
148+ multihop = model .chat (messages , self .sampling_params )
149+
150+ messages .extend (
151+ [
152+ {"role" : "system" , "content" : multihop [0 ].outputs [0 ].text },
153+ {"role" : "user" , "content" : self .extract_prompt_qa .replace ("Qwen" , self .model_name )},
154+ ]
155+ )
156+ qa = model .chat (messages , self .sampling_params )
157+
158+ qa = self .extract_json (qa [0 ].outputs [0 ].text )
159+ if qa is None :
160+ raise ValueError ("Failed to extract valid JSON from model output." )
161+ qa ["thinking" ] = multihop [0 ].outputs [0 ].text
162+ except Exception as e :
163+ if attempt < max_retries : # 如果还有重试机会
164+ logger .warning (f"Attempt { attempt + 1 } failed with error: { e } . Retrying..." )
165+ continue
166+ else : # 重试次数已用完
167+ logger .warning (f"All { max_retries + 1 } attempts failed." )
168+ raise # 重新抛出最后一次的异常
181169
182170 sample .clear ()
183171 sample .update (qa )
0 commit comments