[V1] Refactor parallel sampling support (#13774)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer_group import (
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
@@ -50,9 +50,6 @@ class LLMEngine:
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
|
||||
# Bookkeeping for parallel sampling requests
|
||||
self.parallel_manager = SyncParallelSamplingManager()
|
||||
|
||||
# important: init dp group before init the engine_core
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa
|
||||
@@ -120,8 +117,7 @@ class LLMEngine:
|
||||
multiprocess_mode=enable_multiprocessing)
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
return self.parallel_manager.get_num_unfinished_requests(
|
||||
self.output_processor.get_num_unfinished_requests())
|
||||
return self.output_processor.get_num_unfinished_requests()
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
has_unfinished = self.output_processor.has_unfinished_requests()
|
||||
@@ -157,48 +153,25 @@ class LLMEngine:
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
"""Add request."""
|
||||
kwargs = dict(request_id=request_id,
|
||||
prompt=prompt,
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority)
|
||||
# Handle parallel sampling requests differently.
|
||||
if params is None or isinstance(params,
|
||||
PoolingParams) or params.n == 1:
|
||||
self._add_request(**kwargs)
|
||||
else:
|
||||
# Special handling for parallel sampling requests
|
||||
self.parallel_manager.add_request_parallel_sampling(
|
||||
add_request=self._add_request, **kwargs)
|
||||
# 1) Fan out child requests (for n>1)
|
||||
parent_req = ParentRequest.from_params(request_id, params)
|
||||
n = params.n if isinstance(params, SamplingParams) else 1
|
||||
for idx in range(n):
|
||||
if parent_req is not None:
|
||||
request_id, params = parent_req.get_child_info(idx)
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
"""Add request, `n=1`"""
|
||||
# 1) Process raw inputs into the request.
|
||||
request = self.processor.process_inputs(request_id, prompt, params,
|
||||
arrival_time, lora_request,
|
||||
trace_headers,
|
||||
prompt_adapter_request,
|
||||
priority)
|
||||
# 2) Process raw inputs into the request.
|
||||
request = self.processor.process_inputs(request_id, prompt, params,
|
||||
arrival_time, lora_request,
|
||||
trace_headers,
|
||||
prompt_adapter_request,
|
||||
priority)
|
||||
|
||||
# 2) Make a new RequestState and queue.
|
||||
self.output_processor.add_request(request)
|
||||
# 3) Make a new RequestState and queue.
|
||||
self.output_processor.add_request(request, parent_req, idx)
|
||||
|
||||
# 3) Add the request to EngineCore.
|
||||
self.engine_core.add_request(request)
|
||||
# 3) Add the request to EngineCore.
|
||||
self.engine_core.add_request(request)
|
||||
|
||||
def step(self) -> list[RequestOutput]:
|
||||
|
||||
@@ -217,10 +190,7 @@ class LLMEngine:
|
||||
# 3) Abort any reqs that finished due to stop strings.
|
||||
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
|
||||
|
||||
request_outputs = processed_outputs.request_outputs
|
||||
|
||||
# 4) Process unfinished parallel sampling requests
|
||||
return self.parallel_manager.step(request_outputs)
|
||||
return processed_outputs.request_outputs
|
||||
|
||||
def get_model_config(self):
|
||||
return self.model_config
|
||||
|
||||
Reference in New Issue
Block a user