[core] simplify seq group code (#9569)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
vllm/sequence.py (102 lines changed)
@@ -681,6 +681,7 @@ class SequenceGroup:
     ) -> None:
         self.request_id = request_id
         self.seqs = seqs
+        self.first_seq = seqs[0]
         self.arrival_time = arrival_time
         self.is_single_seq = len(seqs) == 1
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
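Note: `first_seq` caches `seqs[0]` at construction, so the accessors below read an attribute instead of re-indexing the list on every call. A minimal sketch of the pattern, using hypothetical stand-in classes rather than vLLM's real `Sequence`/`SequenceGroup`:

    from typing import List, Optional

    class Seq:
        def __init__(self, seq_id: int, prompt: str) -> None:
            self.seq_id = seq_id
            self.prompt = prompt

    class Group:
        def __init__(self, seqs: List[Seq]) -> None:
            self.seqs = seqs
            # Cache the first sequence once; hot-path accessors reuse it.
            self.first_seq = seqs[0]

        @property
        def prompt(self) -> Optional[str]:
            # All sequences share the same prompt, so the first one suffices.
            return self.first_seq.prompt

    group = Group([Seq(0, "hello")])
    assert group.prompt == "hello"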
@@ -705,15 +706,11 @@ class SequenceGroup:
 
     @property
     def prompt(self) -> Optional[str]:
-        # All sequences in the group should have the same prompt.
-        # We use the prompt of an arbitrary sequence.
-        return self.seqs[0].prompt
+        return self.first_seq.prompt
 
     @property
     def prompt_token_ids(self) -> List[int]:
-        # All sequences in the group should have the same prompt.
-        # We use the prompt of an arbitrary sequence.
-        return self.seqs[0].prompt_token_ids
+        return self.first_seq.prompt_token_ids
 
     @property
     def encoder_prompt(self) -> Optional[str]:
@@ -733,17 +730,11 @@ class SequenceGroup:
 
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
-        # All sequences in the group should have the same multi-modal data.
-        # We use the multi-modal data of an arbitrary sequence.
-        return self.seqs[0].multi_modal_data
+        return self.first_seq.multi_modal_data
 
     @property
     def mm_processor_kwargs(self) -> Dict[str, Any]:
-        # As with multi-modal data, all sequences in the group should have the
-        # same processor kwargs (i.e., mm_processor_kwargs are optionally
-        # provided per request; note that are independent of whether the model
-        # decoder-only or an encoder-decoder).
-        return self.seqs[0].mm_processor_kwargs
+        return self.first_seq.mm_processor_kwargs
 
     @property
     def lora_int_id(self) -> int:
@@ -808,7 +799,7 @@ class SequenceGroup:
             # in TPOT, rather than recalculating TTFT (since from the )
             # POV of the user, there is simply a long generation delay.
             if (self.metrics.first_token_time is None
-                    and self.seqs[0].get_output_len() == 1):
+                    and self.first_seq.get_output_len() == 1):
                 self.metrics.first_token_time = time
 
     def maybe_set_first_scheduled_time(self, time: float) -> None:
@@ -825,18 +816,7 @@ class SequenceGroup:
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
-        if self.sampling_params:
-            n = self.sampling_params.n
-            assert isinstance(n, int)
-            if n > self.num_seqs():
-                # At prompt stage, the sequence group is not yet filled up
-                # and only have one sequence running. However, in the
-                # generation stage, we will have `n` sequences
-                # running.
-                return n
-        # At sampling stages, return the number of actual sequences
-        # that are not finished yet.
-        return self.num_unfinished_seqs()
+        return 0 if self.first_seq.is_finished() else 1
 
     def get_seqs(
         self,
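Note: the old body sized its estimate from `sampling_params.n`; the new one-liner is only equivalent when a group never runs more than one sequence, which this PR arranges by fanning parallel samples out into separate groups (see the `ParallelSampleSequenceGroup` hunk below). A sketch of that equivalence under the single-sequence assumption, using plain stand-in functions rather than the real methods:

    def old_max_running(n: int, num_seqs: int, num_unfinished: int) -> int:
        # Pre-PR logic: at prompt stage the group may later fan out to n seqs.
        if n > num_seqs:
            return n
        return num_unfinished

    def new_max_running(first_seq_finished: bool) -> int:
        # Post-PR logic: a group holds exactly one logical sequence.
        return 0 if first_seq_finished else 1

    # With n == 1 and a single member sequence, the two agree:
    assert old_max_running(1, 1, 1) == new_max_running(False) == 1
    assert old_max_running(1, 1, 0) == new_max_running(True) == 0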
@@ -845,10 +825,7 @@ class SequenceGroup:
         if status is None:
             return self.seqs
 
-        if self.is_single_seq:
-            return self.seqs if self.seqs[0].status == status else []
-
-        return [seq for seq in self.seqs if seq.status == status]
+        return self.seqs if self.first_seq.status == status else []
 
     def is_encoder_decoder(self) -> bool:
         return self.encoder_seq is not None
@@ -856,29 +833,20 @@ class SequenceGroup:
     def get_encoder_seq(self) -> Optional[Sequence]:
         return self.encoder_seq
 
-    def get_unfinished_seqs(self) -> List[Sequence]:
-        if self.is_single_seq:
-            return self.seqs if not self.seqs[0].is_finished() else []
-
-        return [seq for seq in self.seqs if not seq.is_finished()]
-
     def get_finished_seqs(self) -> List[Sequence]:
-        if self.is_single_seq:
-            return self.seqs if self.seqs[0].is_finished() else []
-
-        return [seq for seq in self.seqs if seq.is_finished()]
+        return self.seqs if self.first_seq.is_finished() else []
 
     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
-        for seq in self.seqs:
-            if not seq.is_finished():
-                seq.data.update_num_computed_tokens(num_new_computed_tokens)
+        seq = self.first_seq
+        if not seq.is_finished():
+            seq.data.update_num_computed_tokens(num_new_computed_tokens)
 
     def get_num_uncomputed_tokens(self) -> int:
         num_uncomputed_tokens = 0
-        for seq in self.seqs:
-            if not seq.is_finished():
-                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
+        seq = self.first_seq
+        if not seq.is_finished():
+            num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
         return num_uncomputed_tokens
 
     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
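Note: token accounting now consults only `first_seq`, again relying on the one-sequence-per-group invariant. A rough sketch of the bookkeeping these methods delegate to, with a hypothetical `SeqData` stand-in (not vLLM's real `SequenceData`):

    class SeqData:
        def __init__(self, total_len: int) -> None:
            self.total_len = total_len
            self.num_computed = 0

        def update_num_computed_tokens(self, n: int) -> None:
            self.num_computed += n

        def get_num_uncomputed_tokens(self) -> int:
            return self.total_len - self.num_computed

    data = SeqData(total_len=8)
    data.update_num_computed_tokens(5)   # e.g. after one chunked-prefill step
    assert data.get_num_uncomputed_tokens() == 3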
@@ -892,46 +860,14 @@ class SequenceGroup:
 
         return len(self.get_seqs(status))
 
-    def num_unfinished_seqs(self) -> int:
-        if self.is_single_seq:
-            return 1 if not self.seqs[0].is_finished() else 0
-
-        return len(self.get_unfinished_seqs())
-
     def num_finished_seqs(self) -> int:
-        if self.is_single_seq:
-            return 1 if self.seqs[0].is_finished() else 0
-
-        return len(self.get_finished_seqs())
-
-    def find(self, seq_id: int) -> Sequence:
-        if seq_id not in self.seqs_dict:
-            raise ValueError(f"Sequence {seq_id} not found.")
-        return self.seqs_dict[seq_id]
-
-    def add(self, seq: Sequence) -> None:
-        if seq.seq_id in self.seqs_dict:
-            raise ValueError(f"Sequence {seq.seq_id} already exists.")
-        self.seqs_dict[seq.seq_id] = seq
-        self.seqs.append(seq)
-        self.is_single_seq = len(self.seqs) == 1
-
-    def remove(self, seq_id: int) -> None:
-        seq = self.seqs_dict.pop(seq_id, None)
-        if seq is None:
-            raise ValueError(f"Sequence {seq_id} not found.")
-        self.seqs.remove(seq)
-        self.is_single_seq = len(self.seqs) == 1
+        return 1 if self.first_seq.is_finished() else 0
 
     def is_finished(self) -> bool:
-        if self.is_single_seq:
-            return self.seqs[0].is_finished()
-
-        return all(seq.is_finished() for seq in self.seqs)
+        return self.first_seq.is_finished()
 
     def is_prefill(self) -> bool:
-        # Every sequence should be in the same stage.
-        return self.seqs[0].is_prefill()
+        return self.first_seq.is_prefill()
 
     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
@@ -1455,7 +1391,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
         for i in range(original_params.n):
             request_id_i = f"{request_id}_parallel_sample_{i}"
             group.seq_id_to_index[request_id_i] = i
-            seq_group = engine.add_request(
+            seq_group = engine._add_processed_request(
                 request_id_i,
                 params=params,
                 **kwargs,
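Note: with `SequenceGroup` now effectively single-sequence, parallel sampling (`n > 1`) is handled by `ParallelSampleSequenceGroup`, which registers one child request per sample. The switch from `engine.add_request` to `engine._add_processed_request` suggests the children reuse inputs already processed for the parent request instead of repeating preprocessing `n` times; that reading is inferred from this call site alone. A sketch of the fan-out, with a hypothetical wrapper function and engine object:

    def fan_out_parallel_samples(engine, request_id: str, n: int, params, **kwargs):
        seq_id_to_index = {}
        for i in range(n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            seq_id_to_index[request_id_i] = i
            # Each sample becomes its own single-sequence group; inputs are
            # assumed to be preprocessed once and shared via **kwargs.
            engine._add_processed_request(request_id_i, params=params, **kwargs)
        return seq_id_to_index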