[V1] Refactor parallel sampling support (#13774)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
@@ -25,7 +25,7 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import cdiv, kill_process_tree
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
|
||||
@@ -145,25 +145,30 @@ class AsyncLLM(EngineClient):
|
||||
"""Add new request to the AsyncLLM."""
|
||||
|
||||
# 1) Create a new output queue for the request.
|
||||
if self.output_processor.is_request_active(request_id):
|
||||
raise ValueError(f"Request id {request_id} already running.")
|
||||
queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
|
||||
|
||||
# 2) Convert Input --> Request.
|
||||
request = self.processor.process_inputs(request_id, prompt, params,
|
||||
arrival_time, lora_request,
|
||||
trace_headers,
|
||||
prompt_adapter_request,
|
||||
priority)
|
||||
# 2) Fan out child requests (for n>1)
|
||||
parent_req = ParentRequest.from_params(request_id, params)
|
||||
n = params.n if isinstance(params, SamplingParams) else 1
|
||||
for idx in range(n):
|
||||
if parent_req is not None:
|
||||
request_id, params = parent_req.get_child_info(idx)
|
||||
|
||||
# 3) Add the request to OutputProcessor (this process).
|
||||
self.output_processor.add_request(request, queue)
|
||||
# 3) Convert Input --> Request.
|
||||
request = self.processor.process_inputs(request_id, prompt, params,
|
||||
arrival_time, lora_request,
|
||||
trace_headers,
|
||||
prompt_adapter_request,
|
||||
priority)
|
||||
|
||||
# 4) Add the EngineCoreRequest to EngineCore (separate process).
|
||||
await self.engine_core.add_request_async(request)
|
||||
# 4) Add the request to OutputProcessor (this process).
|
||||
self.output_processor.add_request(request, parent_req, idx, queue)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Added request %s.", request_id)
|
||||
# 5) Add the EngineCoreRequest to EngineCore (separate process).
|
||||
await self.engine_core.add_request_async(request)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Added request %s.", request_id)
|
||||
|
||||
return queue
|
||||
|
||||
@@ -172,7 +177,7 @@ class AsyncLLM(EngineClient):
|
||||
# requests we don't need to send multiple messages to core proc,
|
||||
# and so we don't need multiple streams which then get
|
||||
# re-multiplexed in the API server anyhow.
|
||||
async def _generate(
|
||||
async def generate(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
sampling_params: SamplingParams,
|
||||
@@ -243,30 +248,6 @@ class AsyncLLM(EngineClient):
|
||||
await self.abort(request_id)
|
||||
raise
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
kwargs = dict(prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority)
|
||||
if sampling_params.n is None or sampling_params.n == 1:
|
||||
return self._generate(**kwargs)
|
||||
else:
|
||||
# Special handling for parallel sampling requests
|
||||
return generate_parallel_sampling_async(generate=self._generate,
|
||||
**kwargs)
|
||||
|
||||
async def _run_output_handler(self):
|
||||
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user