[Quality] Add code formatter and linter (#326)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
"""Sequence and its related classes."""
|
||||
import copy
|
||||
import enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
@@ -7,6 +8,7 @@ from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
class SequenceStatus(enum.Enum):
|
||||
"""Status of a sequence."""
|
||||
WAITING = enum.auto()
|
||||
RUNNING = enum.auto()
|
||||
SWAPPED = enum.auto()
|
||||
@@ -21,7 +23,7 @@ class SequenceStatus(enum.Enum):
|
||||
SequenceStatus.FINISHED_STOPPED,
|
||||
SequenceStatus.FINISHED_LENGTH_CAPPED,
|
||||
SequenceStatus.FINISHED_ABORTED,
|
||||
SequenceStatus.FINISHED_IGNORED
|
||||
SequenceStatus.FINISHED_IGNORED,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -40,6 +42,17 @@ class SequenceStatus(enum.Enum):
|
||||
|
||||
|
||||
class SequenceData:
|
||||
"""Data associated with a sequence.
|
||||
|
||||
|
||||
Args:
|
||||
prompt_token_ids: The token IDs of the prompt.
|
||||
|
||||
Attributes:
|
||||
prompt_token_ids: The token IDs of the prompt.
|
||||
output_token_ids: The token IDs of the output.
|
||||
cumulative_logprob: The cumulative log probability of the output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -75,6 +88,15 @@ class SequenceData:
|
||||
|
||||
|
||||
class Sequence:
|
||||
"""Stores the data, status, and block information of a sequence.
|
||||
|
||||
Args:
|
||||
seq_id: The ID of the sequence.
|
||||
prompt: The prompt of the sequence.
|
||||
prompt_token_ids: The token IDs of the prompt.
|
||||
block_size: The block size of the sequence. Should be the same as the
|
||||
block size used by the block manager and cache engine.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -149,19 +171,27 @@ class Sequence:
|
||||
def is_finished(self) -> bool:
|
||||
return SequenceStatus.is_finished(self.status)
|
||||
|
||||
def fork(self, child_seq: 'Sequence') -> None:
|
||||
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
|
||||
def fork(self, child_seq: "Sequence") -> None:
|
||||
child_seq.logical_token_blocks = copy.deepcopy(
|
||||
self.logical_token_blocks)
|
||||
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
|
||||
child_seq.data = copy.deepcopy(self.data)
|
||||
return None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f'Sequence(seq_id={self.seq_id}, '
|
||||
f'status={self.status.name}, '
|
||||
f'num_blocks={len(self.logical_token_blocks)})')
|
||||
return (f"Sequence(seq_id={self.seq_id}, "
|
||||
f"status={self.status.name}, "
|
||||
f"num_blocks={len(self.logical_token_blocks)})")
|
||||
|
||||
|
||||
class SequenceGroup:
|
||||
"""A group of sequences that are generated from the same prompt.
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request.
|
||||
seqs: The list of sequences.
|
||||
sampling_params: The sampling parameters used to generate the outputs.
|
||||
arrival_time: The arrival time of the request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -191,7 +221,7 @@ class SequenceGroup:
|
||||
for seq in self.seqs:
|
||||
if seq.seq_id == seq_id:
|
||||
return seq
|
||||
raise ValueError(f'Sequence {seq_id} not found.')
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return all(seq.is_finished() for seq in self.seqs)
|
||||
@@ -203,14 +233,25 @@ class SequenceGroup:
|
||||
|
||||
|
||||
class SequenceGroupMetadata:
|
||||
"""Metadata for a sequence group. Used to create `InputMetadata`.
|
||||
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request.
|
||||
is_prompt: Whether the request is at prompt stage.
|
||||
seq_data: The sequence data. (Seq id -> sequence data)
|
||||
sampling_params: The sampling parameters used to generate the outputs.
|
||||
block_tables: The block tables. (Seq id -> list of physical block
|
||||
numbers)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
is_prompt: bool,
|
||||
seq_data: Dict[int, SequenceData], # Seq id -> sequence data.
|
||||
seq_data: Dict[int, SequenceData],
|
||||
sampling_params: SamplingParams,
|
||||
block_tables: Dict[int, List[int]], # Seq id -> list of physical block numbers.
|
||||
block_tables: Dict[int, List[int]],
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.is_prompt = is_prompt
|
||||
@@ -220,13 +261,23 @@ class SequenceGroupMetadata:
|
||||
|
||||
|
||||
class SequenceOutputs:
|
||||
"""The model output associated with a sequence.
|
||||
|
||||
Args:
|
||||
seq_id: The ID of the sequence.
|
||||
parent_seq_id: The ID of the parent sequence (for forking in beam
|
||||
search).
|
||||
output_token: The output token ID.
|
||||
logprobs: The logprobs of the output token.
|
||||
(Token id -> logP(x_i+1 | x_0, ..., x_i))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_id: int,
|
||||
parent_seq_id: int,
|
||||
output_token: int,
|
||||
logprobs: Dict[int, float], # Token id -> logP(x_i+1 | x_0, ..., x_i).
|
||||
logprobs: Dict[int, float],
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.parent_seq_id = parent_seq_id
|
||||
@@ -234,15 +285,15 @@ class SequenceOutputs:
|
||||
self.logprobs = logprobs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f'SequenceOutputs(seq_id={self.seq_id}, '
|
||||
f'parent_seq_id={self.parent_seq_id}, '
|
||||
f'output_token={self.output_token}), '
|
||||
f'logprobs={self.logprobs}')
|
||||
return (f"SequenceOutputs(seq_id={self.seq_id}, "
|
||||
f"parent_seq_id={self.parent_seq_id}, "
|
||||
f"output_token={self.output_token}), "
|
||||
f"logprobs={self.logprobs}")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceOutputs):
|
||||
return NotImplemented
|
||||
return (self.seq_id == other.seq_id and
|
||||
self.parent_seq_id == other.parent_seq_id and
|
||||
self.output_token == other.output_token and
|
||||
self.logprobs == other.logprobs)
|
||||
return (self.seq_id == other.seq_id
|
||||
and self.parent_seq_id == other.parent_seq_id
|
||||
and self.output_token == other.output_token
|
||||
and self.logprobs == other.logprobs)
|
||||
|
||||
Reference in New Issue
Block a user