Add `tokenization_kwargs` to `encode` for embedding model truncation (#21033)

This commit is contained in:
Wang Yijun
2025-07-22 23:24:00 +08:00
committed by GitHub
parent 226b452a20
commit 44554a0068
3 changed files with 20 additions and 3 deletions

View File

@@ -438,6 +438,7 @@ class _AsyncLLMEngine(LLMEngine):
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Async version of
@@ -468,6 +469,7 @@ class _AsyncLLMEngine(LLMEngine):
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
tokenization_kwargs=tokenization_kwargs,
)
if isinstance(params, SamplingParams) and \
@@ -862,6 +864,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
if not self.is_running:
if self.start_engine_loop:
@@ -889,6 +892,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
tokenization_kwargs=tokenization_kwargs,
)
return stream.generator()
@@ -996,6 +1000,7 @@ class AsyncLLMEngine(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.
@@ -1070,6 +1075,7 @@ class AsyncLLMEngine(EngineClient):
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
):
yield LLMEngine.validate_output(output, PoolingRequestOutput)
except asyncio.CancelledError: