[V1] Refactor parallel sampling support (#13774)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Author:    Mark McLoughlin
Date:      2025-03-03 16:15:27 +00:00
Committer: GitHub
Parent:    f35f8e2242
Commit:    4167252eaf
5 changed files with 198 additions and 461 deletions


@@ -25,7 +25,7 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.utils import cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
-from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
+from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
@@ -145,25 +145,30 @@ class AsyncLLM(EngineClient):
         """Add new request to the AsyncLLM."""
 
         # 1) Create a new output queue for the request.
         if self.output_processor.is_request_active(request_id):
             raise ValueError(f"Request id {request_id} already running.")
         queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
 
-        # 2) Convert Input --> Request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        # 2) Fan out child requests (for n>1)
+        parent_req = ParentRequest.from_params(request_id, params)
+        n = params.n if isinstance(params, SamplingParams) else 1
+        for idx in range(n):
+            if parent_req is not None:
+                request_id, params = parent_req.get_child_info(idx)
 
-        # 3) Add the request to OutputProcessor (this process).
-        self.output_processor.add_request(request, queue)
+            # 3) Convert Input --> Request.
+            request = self.processor.process_inputs(request_id, prompt, params,
+                                                    arrival_time, lora_request,
+                                                    trace_headers,
+                                                    prompt_adapter_request,
+                                                    priority)
 
-        # 4) Add the EngineCoreRequest to EngineCore (separate process).
-        await self.engine_core.add_request_async(request)
+            # 4) Add the request to OutputProcessor (this process).
+            self.output_processor.add_request(request, parent_req, idx, queue)
 
-        if self.log_requests:
-            logger.info("Added request %s.", request_id)
+            # 5) Add the EngineCoreRequest to EngineCore (separate process).
+            await self.engine_core.add_request_async(request)
+
+            if self.log_requests:
+                logger.info("Added request %s.", request_id)
 
         return queue
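
The loop above is the heart of the refactor: parallel sampling is now handled inline in AsyncLLM.add_request (the AsyncLLM front end, vllm/v1/engine/async_llm.py) instead of by a wrapper generator. It relies on a small contract from ParentRequest: from_params() returns None for ordinary requests, and for n>1 returns an object whose get_child_info(idx) hands back a unique child request id plus per-child sampling params. A minimal sketch of that contract, assuming the child-id format and the n=1 copying details (the real class lives in vllm/v1/engine/parallel_sampling.py):

from copy import copy
from typing import Optional, Union

from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams


class ParentRequestSketch:
    """Toy stand-in for ParentRequest, showing the fan-out contract."""

    def __init__(self, request_id: str, params: SamplingParams) -> None:
        self.request_id = request_id
        self.params = params

    @classmethod
    def from_params(
        cls, request_id: str, params: Union[SamplingParams, PoolingParams]
    ) -> Optional["ParentRequestSketch"]:
        # Only parallel-sampling requests (n > 1) need a parent wrapper;
        # everything else takes the fast path with no extra bookkeeping.
        if isinstance(params, SamplingParams) and params.n > 1:
            return cls(request_id, params)
        return None

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        # Each child is an independent n=1 request with a derived id
        # (the id format here is an assumption, not the library's exact one).
        child_params = copy(self.params)
        child_params.n = 1
        return f"{self.request_id}_{index}", child_params

Because each child goes through the unchanged process_inputs path, EngineCore never needs to know about parallel sampling; only OutputProcessor sees parent_req and idx, which is where the children are aggregated back into one stream.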
@@ -172,7 +177,7 @@ class AsyncLLM(EngineClient):
     # requests we don't need to send multiple messages to core proc,
     # and so we don't need multiple streams which then get
     # re-multiplexed in the API server anyhow.
-    async def _generate(
+    async def generate(
         self,
         prompt: PromptType,
         sampling_params: SamplingParams,
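
This hunk promotes the former _generate() to the public generate() entry point: since add_request() now fans out n>1 itself, the dispatching wrapper (deleted in the next hunk) is no longer needed. From a caller's perspective nothing changes; a usage sketch, assuming an already-constructed AsyncLLM engine and illustrative prompt and parameter values:

from vllm.sampling_params import SamplingParams


async def consume(engine) -> None:
    # 'engine' is assumed to be an already-constructed AsyncLLM.
    # One generate() call streams RequestOutputs for all n samples;
    # the parent request id aggregates the children behind the scenes.
    params = SamplingParams(n=2, max_tokens=32)
    async for output in engine.generate(prompt="Hello, world",
                                        sampling_params=params,
                                        request_id="req-0"):
        if output.finished:
            for completion in output.outputs:
                print(completion.text)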
@@ -243,30 +248,6 @@ class AsyncLLM(EngineClient):
             await self.abort(request_id)
             raise
 
-    def generate(
-        self,
-        prompt: PromptType,
-        sampling_params: SamplingParams,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> AsyncGenerator[RequestOutput, None]:
-        kwargs = dict(prompt=prompt,
-                      sampling_params=sampling_params,
-                      request_id=request_id,
-                      lora_request=lora_request,
-                      trace_headers=trace_headers,
-                      prompt_adapter_request=prompt_adapter_request,
-                      priority=priority)
-        if sampling_params.n is None or sampling_params.n == 1:
-            return self._generate(**kwargs)
-        else:
-            # Special handling for parallel sampling requests
-            return generate_parallel_sampling_async(generate=self._generate,
-                                                    **kwargs)
-
     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""