[V1] Refactor parallel sampling support (#13774)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Mark McLoughlin
2025-03-03 16:15:27 +00:00
committed by GitHub
parent f35f8e2242
commit 4167252eaf
5 changed files with 198 additions and 461 deletions

vllm/v1/engine/llm_engine.py

@@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer_group import (
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
-from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
+from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
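
The refactor swaps SyncParallelSamplingManager for ParentRequest. A minimal sketch, assuming the interface implied by the rest of this diff — from_params() returns None when no fan-out is needed, get_child_info() hands back a per-child request id and params; the child-id scheme and params copying below are illustrative assumptions, not the committed parallel_sampling.py:

    from copy import copy
    from typing import Optional

    from vllm.sampling_params import SamplingParams


    class ParentRequest:
        """Illustrative stand-in for the new fan-out bookkeeping."""

        def __init__(self, request_id: str, params: SamplingParams) -> None:
            self.request_id = request_id
            self.params = params

        @classmethod
        def from_params(cls, request_id: str,
                        params) -> Optional["ParentRequest"]:
            # Only sampling requests with n > 1 need a parent; pooling
            # requests and n == 1 requests go through as a single child.
            if not isinstance(params, SamplingParams) or params.n == 1:
                return None
            return cls(request_id, params)

        def get_child_info(self, idx: int) -> tuple[str, SamplingParams]:
            # Each child samples one sequence; the id scheme here is an
            # assumption for illustration.
            child_params = copy(self.params)
            child_params.n = 1
            return f"{idx}_{self.request_id}", child_params

from_params() returning None for pooling and n == 1 requests is what lets add_request() below treat the single-sequence case as a degenerate fan-out of one.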
@@ -50,9 +50,6 @@ class LLMEngine:
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
 
-        # Bookkeeping for parallel sampling requests
-        self.parallel_manager = SyncParallelSamplingManager()
-
         # important: init dp group before init the engine_core
         self.parallel_config = vllm_config.parallel_config
         self.dp_enabled = self.parallel_config.data_parallel_size > 1  # noqa
@@ -120,8 +117,7 @@ class LLMEngine:
             multiprocess_mode=enable_multiprocessing)
 
     def get_num_unfinished_requests(self) -> int:
-        return self.parallel_manager.get_num_unfinished_requests(
-            self.output_processor.get_num_unfinished_requests())
+        return self.output_processor.get_num_unfinished_requests()
 
     def has_unfinished_requests(self) -> bool:
         has_unfinished = self.output_processor.has_unfinished_requests()
@@ -157,48 +153,25 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> None:
-        """Add request."""
-        kwargs = dict(request_id=request_id,
-                      prompt=prompt,
-                      params=params,
-                      arrival_time=arrival_time,
-                      lora_request=lora_request,
-                      trace_headers=trace_headers,
-                      prompt_adapter_request=prompt_adapter_request,
-                      priority=priority)
-        # Handle parallel sampling requests differently.
-        if params is None or isinstance(params,
-                                        PoolingParams) or params.n == 1:
-            self._add_request(**kwargs)
-        else:
-            # Special handling for parallel sampling requests
-            self.parallel_manager.add_request_parallel_sampling(
-                add_request=self._add_request, **kwargs)
+        # 1) Fan out child requests (for n>1).
+        parent_req = ParentRequest.from_params(request_id, params)
+        n = params.n if isinstance(params, SamplingParams) else 1
+        for idx in range(n):
+            if parent_req is not None:
+                request_id, params = parent_req.get_child_info(idx)
 
-    def _add_request(
-        self,
-        request_id: str,
-        prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> None:
-        """Add request, `n=1`"""
-        # 1) Process raw inputs into the request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+            # 2) Process raw inputs into the request.
+            request = self.processor.process_inputs(request_id, prompt,
+                                                    params, arrival_time,
+                                                    lora_request,
+                                                    trace_headers,
+                                                    prompt_adapter_request,
+                                                    priority)
 
-        # 2) Make a new RequestState and queue.
-        self.output_processor.add_request(request)
+            # 3) Make a new RequestState and queue.
+            self.output_processor.add_request(request, parent_req, idx)
 
-        # 3) Add the request to EngineCore.
-        self.engine_core.add_request(request)
+            # 4) Add the request to EngineCore.
+            self.engine_core.add_request(request)
 
     def step(self) -> list[RequestOutput]:
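
With fan-out done inline in add_request(), callers still see one logical request per prompt. A hedged usage sketch (model choice arbitrary; assumes a build where the V1 engine is active):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(n=3, max_tokens=16))

    # One RequestOutput per prompt; its .outputs list carries the three
    # completions produced by the fanned-out n=1 child requests.
    assert len(outputs) == 1
    assert len(outputs[0].outputs) == 3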
@@ -217,10 +190,7 @@ class LLMEngine:
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
-        request_outputs = processed_outputs.request_outputs
-
-        # 4) Process unfinished parallel sampling requests
-        return self.parallel_manager.step(request_outputs)
+        return processed_outputs.request_outputs
 
     def get_model_config(self):
         return self.model_config
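
step() therefore no longer needs a fourth pass for parallel sampling: child outputs are stitched onto their parent inside the OutputProcessor, which now receives (parent_req, idx) at add_request() time. A sketch of the aggregation idea, using a hypothetical helper rather than the committed OutputProcessor code:

    from vllm.outputs import CompletionOutput, RequestOutput


    def merge_child_output(parent_out: RequestOutput,
                           child_out: CompletionOutput, idx: int) -> None:
        # Slot each child in by its fan-out index so the parent's n
        # completions keep a stable order regardless of which child
        # finishes first.
        child_out.index = idx
        parent_out.outputs.append(child_out)
        parent_out.outputs.sort(key=lambda o: o.index)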