Fix various issues of async servers (#135)

This commit is contained in:
Zhuohan Li
2023-06-05 23:44:50 +08:00
committed by GitHub
parent 8274ca23ac
commit 1a956e136b
11 changed files with 289 additions and 121 deletions

View File

@@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
SWAPPED = enum.auto()
FINISHED_STOPPED = enum.auto()
FINISHED_LENGTH_CAPPED = enum.auto()
FINISHED_ABORTED = enum.auto()
@staticmethod
def is_finished(status: "SequenceStatus") -> bool:
    """Report whether ``status`` is one of the ``FINISHED_*`` states.

    The finished states checked are FINISHED_STOPPED,
    FINISHED_LENGTH_CAPPED and FINISHED_ABORTED.
    """
    finished_states = (
        SequenceStatus.FINISHED_STOPPED,
        SequenceStatus.FINISHED_LENGTH_CAPPED,
        SequenceStatus.FINISHED_ABORTED,
    )
    return status in finished_states
@staticmethod
@@ -26,10 +28,13 @@ class SequenceStatus(enum.Enum):
finish_reason = "stop"
elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
finish_reason = "length"
elif status == SequenceStatus.FINISHED_ABORTED:
finish_reason = "abort"
else:
finish_reason = None
return finish_reason
class SequenceData:
def __init__(
@@ -137,6 +142,9 @@ class Sequence:
def get_cumulative_logprob(self) -> float:
    """Return the ``cumulative_logprob`` value held by ``self.data``."""
    seq_data = self.data
    return seq_data.cumulative_logprob
def is_finished(self) -> bool:
    """Return whether this sequence's current status is a finished one.

    The decision is delegated to ``SequenceStatus.is_finished``.
    """
    current_status = self.status
    return SequenceStatus.is_finished(current_status)
def fork(self, child_seq: 'Sequence') -> None:
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
@@ -182,7 +190,7 @@ class SequenceGroup:
raise ValueError(f'Sequence {seq_id} not found.')
def is_finished(self) -> bool:
    """Return True when every sequence in ``self.seqs`` has finished.

    A group with no sequences is reported as finished, because ``all``
    over an empty iterable is True.
    """
    # Ask each sequence through its own is_finished() so the finished-state
    # check is defined in a single place; the previous body carried a second,
    # unreachable return after the first one.
    return all(seq.is_finished() for seq in self.seqs)
def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, "