Skip to content

Commit 9691b81

Browse files
fix: enable and disable thinking for VLLMWrapper
1 parent adfeb68 commit 9691b81

1 file changed

Lines changed: 15 additions & 10 deletions

File tree

graphgen/models/llm/local/vllm_wrapper.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
import asyncio
12
import math
23
import uuid
34
from typing import Any, List, Optional
4-
import asyncio
55

66
from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
77
from graphgen.bases.datatypes import Token
@@ -43,6 +43,7 @@ def __init__(
4343
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
4444
self.timeout = float(timeout)
4545
self.tokenizer = self.engine.engine.tokenizer.tokenizer
46+
self.enable_thinking = kwargs.get("enable_thinking", False)
4647

4748
def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
4849
messages = history or []
@@ -51,7 +52,8 @@ def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> An
5152
return self.tokenizer.apply_chat_template(
5253
messages,
5354
tokenize=False,
54-
add_generation_prompt=True
55+
add_generation_prompt=True,
56+
enable_thinking=self.enable_thinking,
5557
)
5658

5759
async def _consume_generator(self, generator):
@@ -76,10 +78,11 @@ async def generate_answer(
7678
)
7779

7880
try:
79-
result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
81+
result_generator = self.engine.generate(
82+
full_prompt, sp, request_id=request_id
83+
)
8084
final_output = await asyncio.wait_for(
81-
self._consume_generator(result_generator),
82-
timeout=self.timeout
85+
self._consume_generator(result_generator), timeout=self.timeout
8386
)
8487

8588
if not final_output or not final_output.outputs:
@@ -105,13 +108,13 @@ async def generate_topk_per_token(
105108
)
106109

107110
try:
108-
result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
111+
result_generator = self.engine.generate(
112+
full_prompt, sp, request_id=request_id
113+
)
109114
final_output = await asyncio.wait_for(
110-
self._consume_generator(result_generator),
111-
timeout=self.timeout
115+
self._consume_generator(result_generator), timeout=self.timeout
112116
)
113117

114-
115118
if (
116119
not final_output
117120
or not final_output.outputs
@@ -124,7 +127,9 @@ async def generate_topk_per_token(
124127
candidate_tokens = []
125128
for _, logprob_obj in top_logprobs.items():
126129
tok_str = (
127-
logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
130+
logprob_obj.decoded_token.strip()
131+
if logprob_obj.decoded_token
132+
else ""
128133
)
129134
prob = float(math.exp(logprob_obj.logprob))
130135
candidate_tokens.append(Token(tok_str, prob))

0 commit comments

Comments (0)