[core] move parallel sampling out from vllm core (#9302)

This commit is contained in:
youkaichao
2024-10-21 17:31:44 -07:00
committed by GitHub
parent ef7faad1b8
commit 76a5e13270
4 changed files with 222 additions and 29 deletions

View File

@@ -4,7 +4,7 @@ import enum
from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import cached_property, reduce
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
@@ -17,7 +17,7 @@ from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if TYPE_CHECKING:
@@ -1401,3 +1401,121 @@ class ExecuteModelRequest(
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback)
@dataclass
class SequenceGroupBase:
    """Bookkeeping shared by requests that are split into sub-requests.

    One instance tracks every sub-request derived from a single original
    request: which ones are still pending, which ones have finished, and
    the sequence group that is assembled from them for output.
    Subclasses implement `add_request` and `maybe_assemble_group`.
    """
    # id of the original request, i.e. before it was split
    group_id: str
    # sequence group assembled from the sub-requests (None until it exists)
    assembled_seq_group: Optional[SequenceGroup] = None
    # sub-request id -> unique index of that sub-request inside this group
    seq_id_to_index: Dict[str, int] = field(default_factory=dict)
    # sub-request id -> sequence group that has not finished yet
    to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)
    # sub-request id -> sequence group that has already finished
    finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)
    # whether the group runs in streaming mode
    streaming: bool = False
    # whether the assembled output has already been handed out
    output_produced: bool = False

    @staticmethod
    def add_request(request_id: str, engine, params, *args, **kwargs):
        """Hook invoked when the request `request_id` with `params` is
        ready to be added to `engine`; an implementation may split it
        into several requests.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def finish_seq(self, seq: SequenceGroup):
        """Record that `seq` has finished.

        The entry keyed by `seq.request_id` is removed from
        `to_be_finished` and `seq` is stored in `finished_reqs` under
        the same key.

        Raises:
            KeyError: if `seq.request_id` is not pending in this group.
        """
        finished_id = seq.request_id
        self.to_be_finished.pop(finished_id)
        self.finished_reqs[finished_id] = seq

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled sequence group, if one should be produced
        now (for the final output, or for adding a request to the engine
        again).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
class ParallelSampleSequenceGroup(SequenceGroupBase):
    """Group that implements parallel sampling (`n` > 1) by fanning one
    request out into `n` single-sample sub-requests."""

    @staticmethod
    def add_request(request_id: str, engine, params, **kwargs):
        """Split `request_id` into `params.n` sub-requests and add each
        of them to `engine`.

        Every sub-request gets the id
        ``f"{request_id}_parallel_sample_{i}"`` and a deep copy of
        `params` with ``n`` forced to 1. The new group is registered in
        ``engine.seq_id_to_seq_group`` under each sub-request id, and
        its `assembled_seq_group` is built from the first sequence of
        every sub-request.

        Args:
            request_id: id of the original request.
            engine: object exposing `add_request(...)` and a
                `seq_id_to_seq_group` mapping.
            params: sampling parameters of the original request.
            **kwargs: forwarded unchanged to `engine.add_request`.
        """
        original_params = params
        # each sub-request samples exactly one sequence; the caller's
        # params object is left untouched
        single_sample_params = copy.deepcopy(original_params)
        single_sample_params.n = 1

        group = ParallelSampleSequenceGroup(request_id)
        first_seqs = []
        for index in range(original_params.n):
            sub_request_id = f"{request_id}_parallel_sample_{index}"
            group.seq_id_to_index[sub_request_id] = index
            seq_group = engine.add_request(
                sub_request_id,
                params=single_sample_params,
                **kwargs,
            )  # type: ignore
            assert seq_group is not None
            engine.seq_id_to_seq_group[sub_request_id] = group
            group.to_be_finished[sub_request_id] = seq_group
            first_seqs.append(seq_group.seqs[0])

        # For parallel sampling the assembled group can be built right
        # away: all sequences already exist and the set will not change.
        # Request-level attributes are taken from the last sub-request
        # that was added.
        group.assembled_seq_group = SequenceGroup(
            request_id=request_id,
            seqs=first_seqs,
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            embeddings=seq_group.embeddings,
            pooling_params=seq_group.pooling_params,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            prompt_adapter_request=seq_group.prompt_adapter_request,
            priority=seq_group.priority,
        )

        is_delta = single_sample_params.output_kind == RequestOutputKind.DELTA
        group.streaming = is_delta
        group.output_produced = False

    def maybe_assemble_group(
            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
        """Return the assembled group when output is due, else None.

        Streaming mode: the assembled group is returned whenever
        `seq_group` is the sub-request with index 0, and None for every
        other sub-request.

        Non-streaming mode: None is returned while any sub-request is
        still pending. Once all have finished, the assembled group is
        returned exactly once; later calls return None.
        """
        if self.streaming:
            is_first = self.seq_id_to_index[seq_group.request_id] == 0
            return self.assembled_seq_group if is_first else None

        # non-streaming: wait until every sub-request has finished
        if self.to_be_finished:
            return None

        assembled = self.assembled_seq_group
        assert assembled is not None
        params = assembled.sampling_params
        assert isinstance(params, SamplingParams)

        # the final output is handed out only once
        if self.output_produced:
            return None
        self.output_produced = True

        if params._real_n is not None:
            # Keep only the top-n sequences by cumulative logprob.
            # NOTE(review): `_real_n` is presumably the user-requested n
            # when more candidates were sampled -- confirm against
            # SamplingParams.
            n = params._real_n or params.n
            ranked = sorted(
                assembled.seqs,
                key=lambda seq: seq.get_cumulative_logprob(),
                reverse=True,
            )
            assembled.seqs = ranked[:n]
        return assembled