[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) (#10980)

Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Author: afeldman-nm
Date: 2025-02-24 11:29:41 -05:00
Committed by: GitHub
parent 444b0f0f62
commit befc402d34
5 changed files with 640 additions and 8 deletions
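For context, "parallel sampling" is the `n > 1` case of `SamplingParams`: a single prompt requests `n` independent completions. A minimal user-level sketch of what this commit enables (the model name is a placeholder; any model works):

```python
# Illustrative usage only: one prompt, n=4 sampled completions.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(n=4, temperature=0.8, max_tokens=32)

outputs = llm.generate(["The capital of France is"], params)
for out in outputs:
    # One RequestOutput per prompt; out.outputs holds the n completions.
    for i, completion in enumerate(out.outputs):
        print(f"sample {i}: {completion.text!r}")
```

Before this commit, the V1 engine only handled `n == 1` internally; the changes below fan an `n > 1` request out into `n` single-sample requests and stitch the results back together.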


@@ -21,6 +21,7 @@ from vllm.transformers_utils.tokenizer_group import (
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
@@ -48,6 +49,9 @@ class LLMEngine:
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
+
+        # Bookkeeping for parallel sampling requests
+        self.parallel_manager = SyncParallelSamplingManager()
         # important: init dp group before init the engine_core
         self.parallel_config = vllm_config.parallel_config
         self.dp_enabled = self.parallel_config.data_parallel_size > 1  # noqa
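`SyncParallelSamplingManager` itself lives in the new `vllm/v1/engine/parallel_sampling.py`, which these hunks do not show. Judging purely from its call sites in this file, its job is bookkeeping: remember which engine-core requests are the `n` children of a parallel-sampling parent. A hypothetical sketch of that state (illustrative only, not the real implementation):

```python
# Hypothetical sketch; the real class is in vllm/v1/engine/parallel_sampling.py
# and is not shown in this diff.
class SyncParallelSamplingManager:
    def __init__(self) -> None:
        # Assumed bookkeeping: parent request id -> ids of the n child
        # requests that were submitted to the engine core individually.
        self.parent_to_children: dict[str, set[str]] = {}
```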
@@ -115,7 +119,8 @@ class LLMEngine:
                 multiprocess_mode=enable_multiprocessing)

     def get_num_unfinished_requests(self) -> int:
-        return self.output_processor.get_num_unfinished_requests()
+        return self.parallel_manager.get_num_unfinished_requests(
+            self.output_processor.get_num_unfinished_requests())

     def has_unfinished_requests(self) -> bool:
         has_unfinished = self.output_processor.has_unfinished_requests()
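The wrapper is needed because the engine core and output processor count every child request separately, while the caller submitted only one logical request per parent. A worked count, assuming the manager collapses children into their parent:

```python
# Illustrative arithmetic for get_num_unfinished_requests().
# Suppose 2 plain requests and 1 parallel request with n=4 are in flight:
core_count = 2 + 4       # what the output processor reports (6)
children_in_flight = 4   # children of the single n=4 parent
parents_in_flight = 1
user_visible = core_count - children_in_flight + parents_in_flight
assert user_visible == 3  # the caller submitted 3 requests
```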
@@ -151,7 +156,36 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> None:
+        """Add request."""
+        kwargs = dict(request_id=request_id,
+                      prompt=prompt,
+                      params=params,
+                      arrival_time=arrival_time,
+                      lora_request=lora_request,
+                      trace_headers=trace_headers,
+                      prompt_adapter_request=prompt_adapter_request,
+                      priority=priority)
+        # Handle parallel sampling requests differently.
+        if params is None or isinstance(params,
+                                        PoolingParams) or params.n == 1:
+            self._add_request(**kwargs)
+        else:
+            # Special handling for parallel sampling requests
+            self.parallel_manager.add_request_parallel_sampling(
+                add_request=self._add_request, **kwargs)
+
+    def _add_request(
+        self,
+        request_id: str,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> None:
+        """Add request, `n=1`"""
         # 1) Process raw inputs into the request.
         request = self.processor.process_inputs(request_id, prompt, params,
                                                 arrival_time, lora_request,
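`add_request_parallel_sampling` is defined in the new parallel_sampling module, so only its call site appears above. Because it receives `_add_request` (the `n=1` path) as a callback, a plausible fan-out looks like the following sketch; the child-id scheme and the body are assumptions, not the committed code:

```python
# Hypothetical fan-out inferred from the call site; not the real method body.
from copy import copy
from typing import Any, Callable

def add_request_parallel_sampling(add_request: Callable[..., None],
                                  **kwargs: Any) -> None:
    parent_id: str = kwargs["request_id"]
    params = kwargs["params"]  # SamplingParams with params.n > 1
    for idx in range(params.n):
        child_params = copy(params)
        child_params.n = 1  # each child samples exactly one completion
        child_kwargs = dict(kwargs,
                            request_id=f"{parent_id}_{idx}",  # assumed scheme
                            params=child_params)
        add_request(**child_kwargs)  # reuse the engine's n=1 path
```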
@@ -182,7 +216,10 @@ class LLMEngine:
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

-        return processed_outputs.request_outputs
+        request_outputs = processed_outputs.request_outputs
+
+        # 4) Process unfinished parallel sampling requests
+        return self.parallel_manager.step(request_outputs)

     def get_model_config(self):
         return self.model_config
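Finally, `step()` regroups per-child outputs under their parent before they reach the caller, since each child is an independent request inside the engine core. A rough sketch of that aggregation (field names follow `RequestOutput`; the grouping logic is assumed):

```python
# Rough sketch of SyncParallelSamplingManager.step(); grouping details
# are assumptions, not the committed implementation.
from collections import defaultdict

def step(request_outputs, parent_of):
    """Fold child outputs back into one output per parent request.

    parent_of: assumed mapping of child request id -> parent request id.
    """
    passthrough, by_parent = [], defaultdict(list)
    for out in request_outputs:
        parent_id = parent_of.get(out.request_id)
        if parent_id is None:
            passthrough.append(out)  # ordinary n=1 request
        else:
            by_parent[parent_id].append(out)

    merged = []
    for parent_id, children in by_parent.items():
        base = children[0]
        base.request_id = parent_id
        # Combine the single completion from each child into one
        # RequestOutput carrying all n completions.
        base.outputs = [c.outputs[0] for c in children]
        base.finished = all(c.finished for c in children)
        merged.append(base)
    return passthrough + merged
```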