[core] move parallel sampling out from vllm core (#9302)
This commit is contained in:
122
vllm/sequence.py
122
vllm/sequence.py
@@ -4,7 +4,7 @@ import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from array import array
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, reduce
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
@@ -17,7 +17,7 @@ from vllm.inputs.parse import is_encoder_decoder_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -1401,3 +1401,121 @@ class ExecuteModelRequest(
|
||||
last_sampled_token_ids=self.last_sampled_token_ids.clone()
|
||||
if self.last_sampled_token_ids is not None else None,
|
||||
async_callback=self.async_callback)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceGroupBase:
|
||||
group_id: str # the original request id before splitting
|
||||
|
||||
assembled_seq_group: Optional[SequenceGroup] = None
|
||||
|
||||
# seq id to a unique index inside this group
|
||||
seq_id_to_index: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
# seq ids to be finished
|
||||
to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)
|
||||
|
||||
# seq id to finished sequences
|
||||
finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)
|
||||
|
||||
streaming: bool = False
|
||||
|
||||
output_produced: bool = False
|
||||
|
||||
@staticmethod
|
||||
def add_request(request_id: str, engine, params, *args, **kwargs):
|
||||
"""When we are ready to add a request with request_id and params
|
||||
into the engine, we can split the request into multiple requests.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def finish_seq(self, seq: SequenceGroup):
|
||||
"""The sequence `seq` finishes, we should record the information.
|
||||
"""
|
||||
del self.to_be_finished[seq.request_id]
|
||||
self.finished_reqs[seq.request_id] = seq
|
||||
|
||||
def maybe_assemble_group(
|
||||
self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
|
||||
"""Assemble the sequence group, for producing the final
|
||||
output, or adding request in the engine again.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ParallelSampleSequenceGroup(SequenceGroupBase):
|
||||
|
||||
@staticmethod
|
||||
def add_request(request_id: str, engine, params, **kwargs):
|
||||
original_params = params
|
||||
params = copy.deepcopy(original_params)
|
||||
params.n = 1
|
||||
group = ParallelSampleSequenceGroup(request_id)
|
||||
seqs = []
|
||||
for i in range(original_params.n):
|
||||
request_id_i = f"{request_id}_parallel_sample_{i}"
|
||||
group.seq_id_to_index[request_id_i] = i
|
||||
seq_group = engine.add_request(
|
||||
request_id_i,
|
||||
params=params,
|
||||
**kwargs,
|
||||
) # type: ignore
|
||||
assert seq_group is not None
|
||||
engine.seq_id_to_seq_group[request_id_i] = group
|
||||
group.to_be_finished[request_id_i] = seq_group
|
||||
seqs.append(seq_group.seqs[0])
|
||||
|
||||
# for parallel sampling, the `assembled_seq_group` is always
|
||||
# available, since we have all the sequences ready, and they
|
||||
# will not change.
|
||||
group.assembled_seq_group = SequenceGroup(
|
||||
request_id=request_id,
|
||||
seqs=seqs,
|
||||
arrival_time=seq_group.arrival_time,
|
||||
sampling_params=original_params,
|
||||
lora_request=seq_group.lora_request,
|
||||
embeddings=seq_group.embeddings,
|
||||
pooling_params=seq_group.pooling_params,
|
||||
encoder_seq=seq_group.encoder_seq,
|
||||
trace_headers=seq_group.trace_headers,
|
||||
prompt_adapter_request=seq_group.prompt_adapter_request,
|
||||
priority=seq_group.priority,
|
||||
)
|
||||
|
||||
group.streaming = params.output_kind == RequestOutputKind.DELTA
|
||||
group.output_produced = False
|
||||
|
||||
def maybe_assemble_group(
|
||||
self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
|
||||
|
||||
# in the streaming mode, we will return the assembled sequence
|
||||
# for the first sequence, and then return None for the rest of
|
||||
# sequences
|
||||
if self.streaming:
|
||||
if self.seq_id_to_index[seq_group.request_id] == 0:
|
||||
return self.assembled_seq_group
|
||||
return None
|
||||
|
||||
# in the non-streaming mode, we will return the assembled sequence
|
||||
# once after all sequences finish, and then return None for the
|
||||
# rest of the time
|
||||
|
||||
if len(self.to_be_finished) > 0:
|
||||
return None
|
||||
|
||||
assert self.assembled_seq_group is not None
|
||||
params = self.assembled_seq_group.sampling_params
|
||||
assert isinstance(params, SamplingParams)
|
||||
if not self.output_produced:
|
||||
self.output_produced = True
|
||||
if params._real_n is not None:
|
||||
# Get the top-n sequences.
|
||||
n = params._real_n or params.n
|
||||
seqs = self.assembled_seq_group.seqs
|
||||
sorting_key = lambda seq: seq.get_cumulative_logprob()
|
||||
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
|
||||
top_n_seqs = sorted_seqs[:n]
|
||||
self.assembled_seq_group.seqs = top_n_seqs
|
||||
return self.assembled_seq_group
|
||||
if self.output_produced:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user