[V1] Avoid redundant input processing in n>1 case (#14985)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill
Date: 2025-03-20 22:24:10 -07:00
Committed by: GitHub
Parent: 7297941b38
Commit: da6ea29f7a
13 changed files with 85 additions and 145 deletions
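
Note: before this change, both the V1 LLMEngine and AsyncLLM called Processor.process_inputs() once per child request when parallel sampling (n>1) was requested, so the same prompt was tokenized and preprocessed n times. The diffs below process the inputs once and shallow-copy the resulting request for each child. A minimal self-contained sketch of the new flow; the Request dataclass, process_inputs stand-in, and child-id scheme here are illustrative, not vLLM's actual definitions:

from copy import copy
from dataclasses import dataclass, field


@dataclass
class Request:  # simplified stand-in for vLLM's EngineCoreRequest
    request_id: str
    prompt_token_ids: list[int]  # the expensive-to-produce part
    sampling_params: dict = field(default_factory=dict)


def process_inputs(request_id: str, prompt: str) -> Request:
    # Stand-in for Processor.process_inputs: tokenization, multimodal
    # preprocessing, validation, etc. -- the work we only want to do once.
    return Request(request_id, prompt_token_ids=[ord(c) for c in prompt])


def add_request(request_id: str, prompt: str, n: int) -> list[Request]:
    request = process_inputs(request_id, prompt)  # done once, not n times
    if n == 1:
        return [request]
    children = []
    for idx in range(n):
        # Shallow-copy the processed request; reuse the original object for
        # the last child so exactly n request objects exist in total.
        child = request if idx == n - 1 else copy(request)
        child.request_id = f"{request_id}_{idx}"  # hypothetical child-id scheme
        child.sampling_params = {"n": 1, "seed": idx}  # per-child overrides
        children.append(child)
    return children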

View File

@@ -4,6 +4,7 @@ import asyncio
 import logging
 import os
 from collections.abc import AsyncGenerator, Mapping
+from copy import copy
 from typing import Optional, Union

 import numpy as np
@@ -25,6 +26,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv, kill_process_tree
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -177,34 +179,45 @@ class AsyncLLM(EngineClient):
     ) -> asyncio.Queue[RequestOutput]:
         """Add new request to the AsyncLLM."""

-        # 1) Create a new output queue for the request.
+        # Create a new output queue for the request.
         queue: asyncio.Queue[RequestOutput] = asyncio.Queue()

-        # 2) Fan out child requests (for n>1)
-        parent_req = ParentRequest.from_params(request_id, params)
+        # Convert Input --> Request.
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)

         n = params.n if isinstance(params, SamplingParams) else 1
+        if n == 1:
+            await self._add_request(request, None, 0, queue)
+            return queue
+
+        # Fan out child requests (for n>1).
+        parent_request = ParentRequest(request_id, params)
         for idx in range(n):
-            if parent_req is not None:
-                request_id, params = parent_req.get_child_info(idx)
-
-            # 3) Convert Input --> Request.
-            request = self.processor.process_inputs(request_id, prompt, params,
-                                                    arrival_time, lora_request,
-                                                    trace_headers,
-                                                    prompt_adapter_request,
-                                                    priority)
-
-            # 4) Add the request to OutputProcessor (this process).
-            self.output_processor.add_request(request, parent_req, idx, queue)
-
-            # 5) Add the EngineCoreRequest to EngineCore (separate process).
-            await self.engine_core.add_request_async(request)
-
-            if self.log_requests:
-                logger.info("Added request %s.", request_id)
-
+            request_id, params = parent_request.get_child_info(idx)
+            child_request = request if idx == n - 1 else copy(request)
+            child_request.request_id = request_id
+            child_request.sampling_params = params
+            await self._add_request(child_request, parent_request, idx, queue)
         return queue

+    async def _add_request(self, request: EngineCoreRequest,
+                           parent_req: Optional[ParentRequest], index: int,
+                           queue: asyncio.Queue[RequestOutput]):
+
+        # Add the request to OutputProcessor (this process).
+        self.output_processor.add_request(request, parent_req, index, queue)
+
+        # Add the EngineCoreRequest to EngineCore (separate process).
+        await self.engine_core.add_request_async(request)
+
+        if self.log_requests:
+            logger.info("Added request %s.", request.request_id)
+
     # TODO: we should support multiple prompts in one call, as you
     # can do with LLM.generate. So that for multi-prompt completion
     # requests we don't need to send multiple messages to core proc,
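
Note: the fan-out above relies on copy() producing a shallow copy, so each child request gets its own request_id and sampling_params while the already-processed prompt data is shared rather than recomputed. A small demonstration of that property, using an illustrative stand-in type rather than vLLM's EngineCoreRequest:

from copy import copy
from dataclasses import dataclass


@dataclass
class Req:  # illustrative stand-in, not vLLM's EngineCoreRequest
    request_id: str
    prompt_token_ids: list[int]
    sampling_params: dict


parent = Req("req-0", [101, 2023, 102], {"temperature": 1.0})
child = copy(parent)             # shallow copy: fields are shared references
child.request_id = "req-0_0"     # rebinding a field only affects the child
child.sampling_params = {"temperature": 1.0, "seed": 0}

assert parent.request_id == "req-0"                        # parent untouched
assert child.prompt_token_ids is parent.prompt_token_ids   # heavy data shared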

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from collections.abc import Mapping
+from copy import copy
 from typing import Optional, Union

 from typing_extensions import TypeVar
@@ -179,25 +180,34 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> None:
-        # 1) Fan out child requests (for n>1)
-        parent_req = ParentRequest.from_params(request_id, params)
+        # Process raw inputs into the request.
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)

         n = params.n if isinstance(params, SamplingParams) else 1
-        for idx in range(n):
-            if parent_req is not None:
-                request_id, params = parent_req.get_child_info(idx)
-
-            # 2) Process raw inputs into the request.
-            request = self.processor.process_inputs(request_id, prompt, params,
-                                                    arrival_time, lora_request,
-                                                    trace_headers,
-                                                    prompt_adapter_request,
-                                                    priority)
-
-            # 3) Make a new RequestState and queue.
-            self.output_processor.add_request(request, parent_req, idx)
-
-            # 3) Add the request to EngineCore.
+        if n == 1:
+            # Make a new RequestState and queue.
+            self.output_processor.add_request(request, None, 0)
+            # Add the request to EngineCore.
             self.engine_core.add_request(request)
+            return
+
+        # Fan out child requests (for n>1).
+        parent_req = ParentRequest(request_id, params)
+        for idx in range(n):
+            request_id, params = parent_req.get_child_info(idx)
+            child_request = request if idx == n - 1 else copy(request)
+            child_request.request_id = request_id
+            child_request.sampling_params = params
+
+            # Make a new RequestState and queue.
+            self.output_processor.add_request(child_request, parent_req, idx)
+
+            # Add the request to EngineCore.
+            self.engine_core.add_request(child_request)

     def step(self) -> list[RequestOutput]:
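
Note: nothing changes from the user's point of view; parallel sampling is still requested through SamplingParams.n, and the prompt is now preprocessed once rather than once per child request. A typical offline-inference sketch (the model name is only an example):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any model; opt-125m is just an example
params = SamplingParams(n=4, temperature=0.8, max_tokens=32)

# One prompt, n=4 completions back on a single RequestOutput.
outputs = llm.generate(["The capital of France is"], params)
for completion in outputs[0].outputs:
    print(completion.text)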

View File

@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0

 from copy import copy
-from typing import Optional, Union
+from typing import Optional

 from vllm.outputs import CompletionOutput
-from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.v1.metrics.stats import IterationStats
@@ -43,16 +42,6 @@ class ParentRequest:
         self.max_num_generation_tokens = 0
         self.cached_child_sampling_params = None

-    @classmethod
-    def from_params(
-        cls,
-        request_id: str,
-        params: Union[SamplingParams, PoolingParams],
-    ) -> Optional['ParentRequest']:
-        if not isinstance(params, SamplingParams) or params.n == 1:
-            return None
-        return cls(request_id, params)
-
     def _get_child_sampling_params(
         self,
         index: int,

View File

@@ -173,7 +173,6 @@ class Processor:
         # 3. Apply prompt adapter to prompt token ids if one exists.
         processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=self.use_hash,