[Core] Deduplicate generate/encode logic in AsyncLLM (#31510)

Signed-off-by: njhill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2025-12-29 18:42:45 -08:00
committed by GitHub
parent 358bfd315c
commit e54ee3ea33

View File

@@ -281,6 +281,25 @@ class AsyncLLM(EngineClient):
is_pooling = isinstance(params, PoolingParams)
if (
self.vllm_config.cache_config.kv_sharing_fast_prefill
and not is_pooling
and params.prompt_logprobs
):
raise ValueError(
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, please disable it when the requests need "
"prompt logprobs"
)
if tokenization_kwargs is None:
tokenization_kwargs = {}
_validate_truncation_size(
self.model_config.max_model_len,
params.truncate_prompt_tokens,
tokenization_kwargs,
)
# Convert Input --> Request.
if isinstance(prompt, EngineCoreRequest):
request = prompt
@@ -291,7 +310,10 @@ class AsyncLLM(EngineClient):
"latter will be used, and the former will be ignored."
)
else:
assert prompt_text is None
if prompt_text is not None:
raise ValueError(
"should only provide prompt_text with EngineCoreRequest"
)
request = self.input_processor.process_inputs(
request_id,
prompt,
@@ -310,6 +332,15 @@ class AsyncLLM(EngineClient):
self.input_processor.assign_request_id(request)
# We start the output_handler on the first call to add_request() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
# Respect pause state before accepting new requests.
async with self._pause_cond:
await self._pause_cond.wait_for(lambda: not self._paused)
# Create a new output collector for the request.
queue = RequestOutputCollector(params.output_kind, request.request_id)
@@ -385,37 +416,8 @@ class AsyncLLM(EngineClient):
returning the RequestOutput back to the caller.
"""
if (
self.vllm_config.cache_config.kv_sharing_fast_prefill
and sampling_params.prompt_logprobs
):
raise ValueError(
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, please disable it when the requests need "
"prompt logprobs"
)
q: RequestOutputCollector | None = None
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
# Wait until generation is resumed if the engine is paused.
async with self._pause_cond:
await self._pause_cond.wait_for(lambda: not self._paused)
if tokenization_kwargs is None:
tokenization_kwargs = {}
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
_validate_truncation_size(
self.model_config.max_model_len,
truncate_prompt_tokens,
tokenization_kwargs,
)
q = await self.add_request(
request_id,
prompt,
@@ -639,18 +641,6 @@ class AsyncLLM(EngineClient):
q: RequestOutputCollector | None = None
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
# Respect pause state before accepting new requests.
async with self._pause_cond:
await self._pause_cond.wait_for(lambda: not self._paused)
if tokenization_kwargs is None:
tokenization_kwargs = {}
if truncate_prompt_tokens is not None:
warnings.warn(
"The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` "
@@ -660,12 +650,6 @@ class AsyncLLM(EngineClient):
stacklevel=2,
)
_validate_truncation_size(
self.model_config.max_model_len,
pooling_params.truncate_prompt_tokens,
tokenization_kwargs,
)
q = await self.add_request(
request_id,
prompt,