[V1] Avoid redundant input processing in n>1 case (#14985)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from copy import copy
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -25,6 +26,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Device, cdiv, kill_process_tree
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
@@ -177,34 +179,45 @@ class AsyncLLM(EngineClient):
|
||||
) -> asyncio.Queue[RequestOutput]:
|
||||
"""Add new request to the AsyncLLM."""
|
||||
|
||||
# 1) Create a new output queue for the request.
|
||||
# Create a new output queue for the request.
|
||||
queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
|
||||
|
||||
# 2) Fan out child requests (for n>1)
|
||||
parent_req = ParentRequest.from_params(request_id, params)
|
||||
# Convert Input --> Request.
|
||||
request = self.processor.process_inputs(request_id, prompt, params,
|
||||
arrival_time, lora_request,
|
||||
trace_headers,
|
||||
prompt_adapter_request,
|
||||
priority)
|
||||
|
||||
n = params.n if isinstance(params, SamplingParams) else 1
|
||||
|
||||
if n == 1:
|
||||
await self._add_request(request, None, 0, queue)
|
||||
return queue
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_request = ParentRequest(request_id, params)
|
||||
for idx in range(n):
|
||||
if parent_req is not None:
|
||||
request_id, params = parent_req.get_child_info(idx)
|
||||
|
||||
# 3) Convert Input --> Request.
|
||||
request = self.processor.process_inputs(request_id, prompt, params,
|
||||
arrival_time, lora_request,
|
||||
trace_headers,
|
||||
prompt_adapter_request,
|
||||
priority)
|
||||
|
||||
# 4) Add the request to OutputProcessor (this process).
|
||||
self.output_processor.add_request(request, parent_req, idx, queue)
|
||||
|
||||
# 5) Add the EngineCoreRequest to EngineCore (separate process).
|
||||
await self.engine_core.add_request_async(request)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Added request %s.", request_id)
|
||||
|
||||
request_id, params = parent_request.get_child_info(idx)
|
||||
child_request = request if idx == n - 1 else copy(request)
|
||||
child_request.request_id = request_id
|
||||
child_request.sampling_params = params
|
||||
await self._add_request(child_request, parent_request, idx, queue)
|
||||
return queue
|
||||
|
||||
async def _add_request(self, request: EngineCoreRequest,
|
||||
parent_req: Optional[ParentRequest], index: int,
|
||||
queue: asyncio.Queue[RequestOutput]):
|
||||
# Register a single, already-processed request and submit it for execution.
#
# request:    the EngineCoreRequest to register and send to the engine core.
# parent_req: the ParentRequest when this is one child of an n>1 fan-out;
#             callers pass None for the n == 1 case.
# index:      the child index within the fan-out; callers pass 0 for n == 1.
# queue:      the output queue handed to the output processor along with
#             the request.
#
# NOTE(review): the output processor is called before the engine core;
# presumably so the request is already registered when the core starts
# producing output for it -- confirm before reordering.
|
||||
# Add the request to OutputProcessor (this process).
|
||||
self.output_processor.add_request(request, parent_req, index, queue)
|
||||
|
||||
# Add the EngineCoreRequest to EngineCore (separate process).
|
||||
await self.engine_core.add_request_async(request)
|
||||
|
||||
# Log the id carried by the request itself (for fan-out children the
# caller overwrites request.request_id before calling this method).
if self.log_requests:
|
||||
logger.info("Added request %s.", request.request_id)
|
||||
|
||||
# TODO: we should support multiple prompts in one call, as you
|
||||
# can do with LLM.generate. So that for multi-prompt completion
|
||||
# requests we don't need to send multiple messages to core proc,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import copy
|
||||
from typing import Optional, Union
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
@@ -179,25 +180,34 @@ class LLMEngine:
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
# 1) Fan out child requests (for n>1)
|
||||
parent_req = ParentRequest.from_params(request_id, params)
|
||||
# Process raw inputs into the request.
|
||||
request = self.processor.process_inputs(request_id, prompt, params,
|
||||
arrival_time, lora_request,
|
||||
trace_headers,
|
||||
prompt_adapter_request,
|
||||
priority)
|
||||
|
||||
n = params.n if isinstance(params, SamplingParams) else 1
|
||||
for idx in range(n):
|
||||
if parent_req is not None:
|
||||
request_id, params = parent_req.get_child_info(idx)
|
||||
|
||||
# 2) Process raw inputs into the request.
|
||||
request = self.processor.process_inputs(request_id, prompt, params,
|
||||
arrival_time, lora_request,
|
||||
trace_headers,
|
||||
prompt_adapter_request,
|
||||
priority)
|
||||
|
||||
# 3) Make a new RequestState and queue.
|
||||
self.output_processor.add_request(request, parent_req, idx)
|
||||
|
||||
# 3) Add the request to EngineCore.
|
||||
if n == 1:
|
||||
# Make a new RequestState and queue.
|
||||
self.output_processor.add_request(request, None, 0)
|
||||
# Add the request to EngineCore.
|
||||
self.engine_core.add_request(request)
|
||||
return
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_req = ParentRequest(request_id, params)
|
||||
for idx in range(n):
|
||||
request_id, params = parent_req.get_child_info(idx)
|
||||
child_request = request if idx == n - 1 else copy(request)
|
||||
child_request.request_id = request_id
|
||||
child_request.sampling_params = params
|
||||
|
||||
# Make a new RequestState and queue.
|
||||
self.output_processor.add_request(child_request, parent_req, idx)
|
||||
# Add the request to EngineCore.
|
||||
self.engine_core.add_request(child_request)
|
||||
|
||||
def step(self) -> list[RequestOutput]:
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from copy import copy
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
@@ -43,16 +42,6 @@ class ParentRequest:
|
||||
self.max_num_generation_tokens = 0
|
||||
self.cached_child_sampling_params = None
|
||||
|
||||
@classmethod
|
||||
def from_params(
|
||||
cls,
|
||||
request_id: str,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
) -> Optional['ParentRequest']:
|
||||
# Alternate constructor: build a ParentRequest only when fan-out is needed.
# Returns None when params is not a SamplingParams instance or when
# params.n == 1; otherwise returns cls(request_id, params).
if not isinstance(params, SamplingParams) or params.n == 1:
|
||||
return None
|
||||
return cls(request_id, params)
|
||||
|
||||
def _get_child_sampling_params(
|
||||
self,
|
||||
index: int,
|
||||
|
||||
@@ -173,7 +173,6 @@ class Processor:
|
||||
# 3. Apply prompt adapter to prompt token ids if one exists.
|
||||
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
return_mm_hashes=self.use_hash,
|
||||
|
||||
Reference in New Issue
Block a user