[core] move parallel sampling out from vllm core (#9302)

commit 76a5e13270 (parent ef7faad1b8)
Author: youkaichao
Date:   2024-10-21 17:31:44 -07:00
Committed by: GitHub

4 changed files with 222 additions and 29 deletions
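
What this change does: previously, an `n > 1` sampling request was expanded inside the core engine itself. After this commit, `LLMEngine.add_request` fans such a request out into `n` independent single-sample child requests, tracked by a new `ParallelSampleSequenceGroup` wrapper, and a new `seq_id_to_seq_group` map lets the output path stitch the children back into one aggregated `RequestOutput`. A minimal sketch of that fan-out/fan-in pattern follows; the class name `FanOutGroup` and the `<i>_<request_id>` child-id format are illustrative assumptions, not vLLM's actual implementation.

```python
# Hypothetical sketch of the fan-out/fan-in pattern this commit adopts.
# Names and id formats are illustrative; the real implementation lives in
# vllm.sequence.ParallelSampleSequenceGroup / SequenceGroupBase.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class FanOutGroup:
    """Tracks the n single-sample child requests spawned from one n>1 request."""
    parent_id: str
    n: int
    completions: Dict[str, str] = field(default_factory=dict)  # child id -> text

    def child_ids(self) -> List[str]:
        # Each child looks like an ordinary n=1 request to the core engine.
        return [f"{i}_{self.parent_id}" for i in range(self.n)]

    def finish_child(self, child_id: str, text: str) -> Optional[List[str]]:
        # Returns the aggregated result once the last child completes.
        self.completions[child_id] = text
        if len(self.completions) == self.n:
            return [self.completions[c] for c in self.child_ids()]
        return None


group = FanOutGroup(parent_id="req-42", n=2)
assert group.finish_child("0_req-42", "first sample") is None
print(group.finish_child("1_req-42", "second sample"))
# -> ['first sample', 'second sample']
```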

vllm/engine/llm_engine.py

@@ -44,8 +44,10 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceStatus)
+                           ParallelSampleSequenceGroup, Sequence,
+                           SequenceGroup, SequenceGroupBase,
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +476,8 @@ class LLMEngine:
             ),
         ))

+        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -642,7 +646,10 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> SequenceGroup:
+        """Add a processed request to the engine's request pool.
+        Return the created sequence group.
+        """
         self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -696,6 +703,8 @@ class LLMEngine:
         min_cost_scheduler = self.scheduler[costs.index(min(costs))]
         min_cost_scheduler.add_seq_group(seq_group)

+        return seq_group
+
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
@@ -711,7 +720,7 @@ class LLMEngine:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...

     @overload
@@ -725,7 +734,7 @@ class LLMEngine:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...

     @deprecate_kwargs(
@@ -744,7 +753,7 @@ class LLMEngine:
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         """Add a request to the engine's request pool.

         The request is added to the request pool and will be processed by the
@@ -788,6 +797,22 @@ class LLMEngine:
         >>> # continue the request processing
         >>> ...
         """
+        if isinstance(params, SamplingParams) and params.n > 1:
+            ParallelSampleSequenceGroup.add_request(
+                request_id,
+                self,
+                params,
+                prompt=prompt,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+                inputs=inputs,
+            )
+            return None
+
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
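
This hunk is the heart of the commit: `n > 1` requests are intercepted before the core request path and delegated to `ParallelSampleSequenceGroup.add_request`, which registers the child requests itself; `add_request` then returns `None` instead of a `SequenceGroup`. A caller-side usage sketch, assuming the `LLMEngine` API of this era (`from_engine_args`, `add_request`, `step`); the model name is arbitrary:

```python
from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# n > 1: taken over by ParallelSampleSequenceGroup.add_request, which fans
# the request out into n single-sample children; add_request returns None.
engine.add_request("req-0", "Hello, my name is",
                   SamplingParams(n=4, max_tokens=16))

# n == 1: the ordinary path; the created SequenceGroup is returned.
engine.add_request("req-1", "Hello, my name is",
                   SamplingParams(max_tokens=16))

while engine.has_unfinished_requests():
    for out in engine.step():
        if out.finished:
            # "req-0" should surface once, with 4 completions aggregated.
            print(out.request_id, len(out.outputs))
```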
@@ -818,7 +843,7 @@ class LLMEngine:
         processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
             "mm_processor_kwargs")

-        self._add_processed_request(
+        return self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=params,
@@ -1135,7 +1160,9 @@ class LLMEngine:
                 seq_group = scheduled_seq_group.seq_group
                 seq_group.maybe_set_first_token_time(now)
                 request_output = RequestOutputFactory.create(
-                    seq_group, use_cache=self.use_cached_outputs)
+                    seq_group,
+                    self.seq_id_to_seq_group,
+                    use_cache=self.use_cached_outputs)
                 if request_output:
                     ctx.request_outputs.append(request_output)
@@ -1175,7 +1202,9 @@ class LLMEngine:
                 seq_group = scheduled_seq_group.seq_group
                 seq_group.maybe_set_first_token_time(now)
                 request_output = RequestOutputFactory.create(
-                    seq_group, use_cache=self.use_cached_outputs)
+                    seq_group,
+                    self.seq_id_to_seq_group,
+                    use_cache=self.use_cached_outputs)
                 if request_output:
                     ctx.request_outputs.append(request_output)
@@ -1194,7 +1223,10 @@ class LLMEngine:
                     continue

                 request_output = RequestOutputFactory.create(
-                    seq_group, use_cache=self.use_cached_outputs)
+                    seq_group,
+                    self.seq_id_to_seq_group,
+                    use_cache=self.use_cached_outputs,
+                )
                 if request_output:
                     ctx.request_outputs.append(request_output)
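
All three call sites above thread `self.seq_id_to_seq_group` into `RequestOutputFactory.create`, which is what lets the factory distinguish a parallel-sampling child from an ordinary request. A hypothetical sketch of that dispatch logic (not vLLM's actual factory code):

```python
from typing import Dict, List, Optional


class ParentGroup:
    """Stand-in for the parent group that collects n child results."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.finished_children: List[str] = []

    def maybe_assemble(self, child_id: str) -> Optional[List[str]]:
        self.finished_children.append(child_id)
        if len(self.finished_children) == self.n:
            return self.finished_children  # last child: emit one aggregate
        return None  # siblings still running: emit nothing yet


def create(child_id: str,
           seq_id_to_seq_group: Dict[str, ParentGroup]) -> Optional[List[str]]:
    parent = seq_id_to_seq_group.get(child_id)
    if parent is None:
        return [child_id]  # ordinary n == 1 request: emit directly
    return parent.maybe_assemble(child_id)
```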