[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) (#10980)

Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
afeldman-nm
2025-02-24 11:29:41 -05:00
committed by GitHub
parent 444b0f0f62
commit befc402d34
5 changed files with 640 additions and 8 deletions

View File

@@ -24,6 +24,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils import cdiv, kill_process_tree
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
@@ -170,7 +171,7 @@ class AsyncLLM(EngineClient):
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
async def generate(
async def _generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
@@ -241,6 +242,30 @@ class AsyncLLM(EngineClient):
await self.abort(request_id)
raise
def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
kwargs = dict(prompt=prompt,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority)
if sampling_params.n is None or sampling_params.n == 1:
return self._generate(**kwargs)
else:
# Special handling for parallel sampling requests
return generate_parallel_sampling_async(generate=self._generate,
**kwargs)
async def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""