[core] move parallel sampling out from vllm core (#9302)
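In short: parallel sampling (SamplingParams.n > 1) is no longer expanded inside the engine core. LLMEngine.add_request now hands such requests to ParallelSampleSequenceGroup.add_request, which fans them out into n single-sample child requests; the new engine-level seq_id_to_seq_group map lets RequestOutputFactory.create fold finished children back into one output for the original request, and _add_processed_request / add_request now return the created SequenceGroup so the fan-out code can track its children. Below is a minimal sketch of the fan-out-and-reassemble idea; apart from the ParallelSampleSequenceGroup name it refers to, every identifier is a hypothetical stand-in, not the actual vllm implementation.

    from dataclasses import dataclass, field
    from typing import Dict, List, Optional


    @dataclass
    class ParallelSampleGroup:
        """Hypothetical stand-in for ParallelSampleSequenceGroup: one user request
        with n > 1 samples, fanned out into n single-sample child requests."""
        request_id: str
        n: int
        finished: Dict[str, str] = field(default_factory=dict)  # child id -> generated text

        def child_ids(self) -> List[str]:
            # Assumed naming scheme: child requests derive their ids from the parent id.
            return [f"{self.request_id}_parallel_sample_{i}" for i in range(self.n)]

        def maybe_assemble(self, child_id: str, text: str) -> Optional[List[str]]:
            # Record one finished child; only once all n children are done does the
            # parent request emit a combined result (mirroring the deferred RequestOutput).
            self.finished[child_id] = text
            if len(self.finished) < self.n:
                return None
            return [self.finished[cid] for cid in self.child_ids()]


    # The seq_id_to_seq_group dict added in this diff plays the role of this map:
    # every child id resolves back to the parent group that owns it.
    seq_id_to_seq_group: Dict[str, ParallelSampleGroup] = {}
    group = ParallelSampleGroup(request_id="req-0", n=2)
    for cid in group.child_ids():
        seq_id_to_seq_group[cid] = group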
@@ -44,8 +44,10 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceStatus)
+                           ParallelSampleSequenceGroup, Sequence,
+                           SequenceGroup, SequenceGroupBase,
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +476,8 @@ class LLMEngine:
                 ),
             ))

+        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).

@@ -642,7 +646,10 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> SequenceGroup:
+        """Add a processed request to the engine's request pool.
+        return the created sequence group.
+        """
         self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -696,6 +703,8 @@ class LLMEngine:
         min_cost_scheduler = self.scheduler[costs.index(min(costs))]
         min_cost_scheduler.add_seq_group(seq_group)

+        return seq_group
+
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()

@@ -711,7 +720,7 @@ class LLMEngine:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...

     @overload
@@ -725,7 +734,7 @@ class LLMEngine:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...

     @deprecate_kwargs(
@@ -744,7 +753,7 @@ class LLMEngine:
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         """Add a request to the engine's request pool.

         The request is added to the request pool and will be processed by the
@@ -788,6 +797,22 @@ class LLMEngine:
         >>> # continue the request processing
         >>> ...
         """
+
+        if isinstance(params, SamplingParams) and params.n > 1:
+            ParallelSampleSequenceGroup.add_request(
+                request_id,
+                self,
+                params,
+                prompt=prompt,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+                inputs=inputs,
+            )
+            return None
+
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
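From the caller's side nothing changes: a request with n > 1 still goes through the same add_request path and comes back as a single RequestOutput carrying n completions; the fan-out above is purely internal. A usage sketch against the public LLM API (model name and prompt are placeholders):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    params = SamplingParams(n=4, temperature=0.8, max_tokens=32)

    # n > 1 now takes the ParallelSampleSequenceGroup path inside LLMEngine.add_request,
    # but each prompt still yields one RequestOutput with four completions.
    for output in llm.generate(["The capital of France is"], params):
        for completion in output.outputs:
            print(completion.text)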
@@ -818,7 +843,7 @@ class LLMEngine:
         processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
             "mm_processor_kwargs")

-        self._add_processed_request(
+        return self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=params,
@@ -1135,7 +1160,9 @@ class LLMEngine:
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)

@@ -1175,7 +1202,9 @@ class LLMEngine:
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)

@@ -1194,7 +1223,10 @@ class LLMEngine:
                 continue

             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs,
+            )
             if request_output:
                 ctx.request_outputs.append(request_output)

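The three output-building hunks above make the same change: RequestOutputFactory.create now also receives the engine's seq_id_to_seq_group map, so a finished child sequence group can be handed to its parent instead of producing its own output, and the existing if request_output guard then presumably skips children whose parent is still waiting. A rough sketch of that dispatch, reusing the hypothetical ParallelSampleGroup from the first sketch (not the real factory signature):

    from typing import Dict, List, Optional

    def create_request_output(child_id: str, text: str,
                              seq_id_to_seq_group: Dict[str, "ParallelSampleGroup"],
                              ) -> Optional[List[str]]:
        # Children of a parallel-sampling request are folded into their parent; an
        # output is emitted only once the parent has gathered all n samples.
        parent = seq_id_to_seq_group.get(child_id)
        if parent is not None:
            return parent.maybe_assemble(child_id, text)  # None while siblings are pending
        return [text]  # ordinary n == 1 request: emit immediately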