[Bugfix][Frontend] Guard against bad token ids (#9634)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde
2024-10-29 16:13:20 -05:00
committed by GitHub
parent 0ad216f575
commit 67bdf8e523
7 changed files with 89 additions and 17 deletions

View File

@@ -412,6 +412,12 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def get_tokenizer_async(self,
                              lora_request: Optional[LoRARequest] = None
                              ) -> AnyTokenizer:
    """Fetch the tokenizer for *lora_request* (or the base tokenizer).

    Delegates to the engine's tokenizer group, which resolves the
    LoRA-specific tokenizer asynchronously when a request is given.
    """
    tokenizer_group = self.get_tokenizer_group()
    return await tokenizer_group.get_lora_tokenizer_async(lora_request)
@overload # DEPRECATED
async def add_request_async(
self,
@@ -472,6 +478,10 @@ class _AsyncLLMEngine(LLMEngine):
if arrival_time is None:
arrival_time = time.time()
if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
request_id=request_id,
@@ -488,7 +498,7 @@ class _AsyncLLMEngine(LLMEngine):
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=self.get_tokenizer(lora_request),
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)
@@ -715,8 +725,7 @@ class AsyncLLMEngine(EngineClient):
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await (self.engine.get_tokenizer_group().
get_lora_tokenizer_async(lora_request))
return await self.engine.get_tokenizer_async(lora_request)
def start_background_loop(self) -> None:
"""Start the background loop."""