[Core] Deduplicate generate/encode logic in AsyncLLM (#31510)
Signed-off-by: njhill <nickhill123@gmail.com>
This commit is contained in:
@@ -281,6 +281,25 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
is_pooling = isinstance(params, PoolingParams)
|
||||
|
||||
if (
|
||||
self.vllm_config.cache_config.kv_sharing_fast_prefill
|
||||
and not is_pooling
|
||||
and params.prompt_logprobs
|
||||
):
|
||||
raise ValueError(
|
||||
"--kv-sharing-fast-prefill produces incorrect logprobs for "
|
||||
"prompt tokens, please disable it when the requests need "
|
||||
"prompt logprobs"
|
||||
)
|
||||
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
_validate_truncation_size(
|
||||
self.model_config.max_model_len,
|
||||
params.truncate_prompt_tokens,
|
||||
tokenization_kwargs,
|
||||
)
|
||||
|
||||
# Convert Input --> Request.
|
||||
if isinstance(prompt, EngineCoreRequest):
|
||||
request = prompt
|
||||
@@ -291,7 +310,10 @@ class AsyncLLM(EngineClient):
|
||||
"latter will be used, and the former will be ignored."
|
||||
)
|
||||
else:
|
||||
assert prompt_text is None
|
||||
if prompt_text is not None:
|
||||
raise ValueError(
|
||||
"should only provide prompt_text with EngineCoreRequest"
|
||||
)
|
||||
request = self.input_processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
@@ -310,6 +332,15 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
self.input_processor.assign_request_id(request)
|
||||
|
||||
# We start the output_handler on the first call to add_request() so
|
||||
# we can call __init__ before the event loop, which enables us
|
||||
# to handle startup failure gracefully in the OpenAI server.
|
||||
self._run_output_handler()
|
||||
|
||||
# Respect pause state before accepting new requests.
|
||||
async with self._pause_cond:
|
||||
await self._pause_cond.wait_for(lambda: not self._paused)
|
||||
|
||||
# Create a new output collector for the request.
|
||||
queue = RequestOutputCollector(params.output_kind, request.request_id)
|
||||
|
||||
@@ -385,37 +416,8 @@ class AsyncLLM(EngineClient):
|
||||
returning the RequestOutput back to the caller.
|
||||
"""
|
||||
|
||||
if (
|
||||
self.vllm_config.cache_config.kv_sharing_fast_prefill
|
||||
and sampling_params.prompt_logprobs
|
||||
):
|
||||
raise ValueError(
|
||||
"--kv-sharing-fast-prefill produces incorrect logprobs for "
|
||||
"prompt tokens, please disable it when the requests need "
|
||||
"prompt logprobs"
|
||||
)
|
||||
|
||||
q: RequestOutputCollector | None = None
|
||||
try:
|
||||
# We start the output_handler on the first call to generate() so
|
||||
# we can call __init__ before the event loop, which enables us
|
||||
# to handle startup failure gracefully in the OpenAI server.
|
||||
self._run_output_handler()
|
||||
|
||||
# Wait until generation is resumed if the engine is paused.
|
||||
async with self._pause_cond:
|
||||
await self._pause_cond.wait_for(lambda: not self._paused)
|
||||
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
|
||||
|
||||
_validate_truncation_size(
|
||||
self.model_config.max_model_len,
|
||||
truncate_prompt_tokens,
|
||||
tokenization_kwargs,
|
||||
)
|
||||
|
||||
q = await self.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
@@ -639,18 +641,6 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
q: RequestOutputCollector | None = None
|
||||
try:
|
||||
# We start the output_handler on the first call to generate() so
|
||||
# we can call __init__ before the event loop, which enables us
|
||||
# to handle startup failure gracefully in the OpenAI server.
|
||||
self._run_output_handler()
|
||||
|
||||
# Respect pause state before accepting new requests.
|
||||
async with self._pause_cond:
|
||||
await self._pause_cond.wait_for(lambda: not self._paused)
|
||||
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
|
||||
if truncate_prompt_tokens is not None:
|
||||
warnings.warn(
|
||||
"The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` "
|
||||
@@ -660,12 +650,6 @@ class AsyncLLM(EngineClient):
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
_validate_truncation_size(
|
||||
self.model_config.max_model_len,
|
||||
pooling_params.truncate_prompt_tokens,
|
||||
tokenization_kwargs,
|
||||
)
|
||||
|
||||
q = await self.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
|
||||
Reference in New Issue
Block a user