[Bugfix][Frontend] Guard against bad token ids (#9634)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
@@ -412,6 +412,12 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
"""Stop the remote worker execution loop."""
|
||||
await self.model_executor.stop_remote_worker_execution_loop_async()
|
||||
|
||||
async def get_tokenizer_async(self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> AnyTokenizer:
|
||||
return await (
|
||||
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
|
||||
|
||||
@overload # DEPRECATED
|
||||
async def add_request_async(
|
||||
self,
|
||||
@@ -472,6 +478,10 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
if self.tokenizer is not None:
|
||||
tokenizer = await self.get_tokenizer_async(lora_request)
|
||||
self._validate_token_prompt(prompt, tokenizer=tokenizer)
|
||||
|
||||
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
@@ -488,7 +498,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
# implementation in the LLMEngine
|
||||
params = await build_guided_decoding_logits_processor_async(
|
||||
sampling_params=params,
|
||||
tokenizer=self.get_tokenizer(lora_request),
|
||||
tokenizer=await self.get_tokenizer_async(lora_request),
|
||||
default_guided_backend=self.decoding_config.
|
||||
guided_decoding_backend)
|
||||
|
||||
@@ -715,8 +725,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AnyTokenizer:
|
||||
return await (self.engine.get_tokenizer_group().
|
||||
get_lora_tokenizer_async(lora_request))
|
||||
return await self.engine.get_tokenizer_async(lora_request)
|
||||
|
||||
def start_background_loop(self) -> None:
|
||||
"""Start the background loop."""
|
||||
|
||||
Reference in New Issue
Block a user