[V1] Avoid redundant input processing in n>1 case (#14985)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-03-20 22:24:10 -07:00
committed by GitHub
parent 7297941b38
commit da6ea29f7a
13 changed files with 85 additions and 145 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Mapping
from copy import copy
from typing import Optional, Union
from typing_extensions import TypeVar
@@ -179,25 +180,34 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
# 1) Fan out child requests (for n>1)
parent_req = ParentRequest.from_params(request_id, params)
# Process raw inputs into the request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)
n = params.n if isinstance(params, SamplingParams) else 1
for idx in range(n):
if parent_req is not None:
request_id, params = parent_req.get_child_info(idx)
# 2) Process raw inputs into the request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)
# 3) Make a new RequestState and queue.
self.output_processor.add_request(request, parent_req, idx)
# 3) Add the request to EngineCore.
if n == 1:
# Make a new RequestState and queue.
self.output_processor.add_request(request, None, 0)
# Add the request to EngineCore.
self.engine_core.add_request(request)
return
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
for idx in range(n):
request_id, params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
# Make a new RequestState and queue.
self.output_processor.add_request(child_request, parent_req, idx)
# Add the request to EngineCore.
self.engine_core.add_request(child_request)
def step(self) -> list[RequestOutput]: