@@ -1,7 +1,7 @@
+import asyncio
 import math
 import uuid
 from typing import Any, List, Optional
-import asyncio
 
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
@@ -43,6 +43,7 @@ def __init__(
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         self.timeout = float(timeout)
         self.tokenizer = self.engine.engine.tokenizer.tokenizer
+        self.enable_thinking = kwargs.get("enable_thinking", False)
 
     def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
         messages = history or []
@@ -51,7 +52,8 @@ def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
         return self.tokenizer.apply_chat_template(
             messages,
             tokenize=False,
-            add_generation_prompt=True
+            add_generation_prompt=True,
+            enable_thinking=self.enable_thinking,
         )
 
     async def _consume_generator(self, generator):
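Note: the new `enable_thinking` flag is simply forwarded to `apply_chat_template`, where Hugging Face tokenizers pass extra keyword arguments through as chat-template variables. A minimal sketch of the effect, assuming a Qwen3-style template that reads `enable_thinking` (the model name here is only an example; templates that do not use the variable ignore it):

```python
from transformers import AutoTokenizer

# Sketch only: assumes a Qwen3-style chat template that understands
# `enable_thinking`; other templates simply ignore the extra keyword.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
messages = [{"role": "user", "content": "What is 2 + 2?"}]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,  # asks the template to suppress the model's thinking block
)
print(prompt)
```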
@@ -76,10 +78,11 @@ async def generate_answer(
         )
 
         try:
-            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+            result_generator = self.engine.generate(
+                full_prompt, sp, request_id=request_id
+            )
             final_output = await asyncio.wait_for(
-                self._consume_generator(result_generator),
-                timeout=self.timeout
+                self._consume_generator(result_generator), timeout=self.timeout
             )
 
             if not final_output or not final_output.outputs:
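Note: both generation paths guard the consuming coroutine with `asyncio.wait_for`, which cancels it once `self.timeout` seconds elapse. A small, self-contained sketch of that pattern (`slow_generation` is a hypothetical stand-in for draining the vLLM result generator):

```python
import asyncio

async def slow_generation() -> str:
    # Hypothetical stand-in for consuming the vLLM result generator.
    await asyncio.sleep(5)
    return "answer"

async def main() -> None:
    try:
        result = await asyncio.wait_for(slow_generation(), timeout=1.0)
    except asyncio.TimeoutError:
        # The coroutine is cancelled; the wrapper would surface an empty result here.
        result = None
    print(result)

asyncio.run(main())
```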
@@ -105,13 +108,13 @@ async def generate_topk_per_token(
         )
 
         try:
-            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+            result_generator = self.engine.generate(
+                full_prompt, sp, request_id=request_id
+            )
             final_output = await asyncio.wait_for(
-                self._consume_generator(result_generator),
-                timeout=self.timeout
+                self._consume_generator(result_generator), timeout=self.timeout
             )
 
-
             if (
                 not final_output
                 or not final_output.outputs
@@ -124,7 +127,9 @@ async def generate_topk_per_token(
             candidate_tokens = []
             for _, logprob_obj in top_logprobs.items():
                 tok_str = (
-                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+                    logprob_obj.decoded_token.strip()
+                    if logprob_obj.decoded_token
+                    else ""
                 )
                 prob = float(math.exp(logprob_obj.logprob))
                 candidate_tokens.append(Token(tok_str, prob))
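Note: vLLM reports the top-k candidates as natural-log probabilities, so `math.exp` converts each one into a probability before it is wrapped in a `Token`. A quick illustration of that conversion (the values below are made up):

```python
import math

# Made-up logprobs for three candidate tokens at one position.
top_logprobs = {" yes": -0.105, " no": -2.41, " maybe": -4.83}

for tok, logprob in top_logprobs.items():
    prob = float(math.exp(logprob))  # e.g. exp(-0.105) ~= 0.90
    print(f"{tok!r}: {prob:.3f}")
```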