[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) (#10980)
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from vllm.transformers_utils.tokenizer_group import (
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
@@ -48,6 +49,9 @@ class LLMEngine:
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
|
||||
# Bookkeeping for parallel sampling requests
|
||||
self.parallel_manager = SyncParallelSamplingManager()
|
||||
|
||||
# important: init dp group before init the engine_core
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa
|
||||
@@ -115,7 +119,8 @@ class LLMEngine:
|
||||
multiprocess_mode=enable_multiprocessing)
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
return self.output_processor.get_num_unfinished_requests()
|
||||
return self.parallel_manager.get_num_unfinished_requests(
|
||||
self.output_processor.get_num_unfinished_requests())
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
has_unfinished = self.output_processor.has_unfinished_requests()
|
||||
@@ -151,7 +156,36 @@ class LLMEngine:
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
"""Add request."""
|
||||
kwargs = dict(request_id=request_id,
|
||||
prompt=prompt,
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority)
|
||||
# Handle parallel sampling requests differently.
|
||||
if params is None or isinstance(params,
|
||||
PoolingParams) or params.n == 1:
|
||||
self._add_request(**kwargs)
|
||||
else:
|
||||
# Special handling for parallel sampling requests
|
||||
self.parallel_manager.add_request_parallel_sampling(
|
||||
add_request=self._add_request, **kwargs)
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
"""Add request, `n=1`"""
|
||||
# 1) Process raw inputs into the request.
|
||||
request = self.processor.process_inputs(request_id, prompt, params,
|
||||
arrival_time, lora_request,
|
||||
@@ -182,7 +216,10 @@ class LLMEngine:
|
||||
# 3) Abort any reqs that finished due to stop strings.
|
||||
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
|
||||
|
||||
return processed_outputs.request_outputs
|
||||
request_outputs = processed_outputs.request_outputs
|
||||
|
||||
# 4) Process unfinished parallel sampling requests
|
||||
return self.parallel_manager.step(request_outputs)
|
||||
|
||||
def get_model_config(self):
|
||||
return self.model_config
|
||||
|
||||
Reference in New Issue
Block a user