Fix various issues of async servers (#135)
This commit is contained in:
@@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
|
||||
SWAPPED = enum.auto()
|
||||
FINISHED_STOPPED = enum.auto()
|
||||
FINISHED_LENGTH_CAPPED = enum.auto()
|
||||
FINISHED_ABORTED = enum.auto()
|
||||
|
||||
@staticmethod
def is_finished(status: "SequenceStatus") -> bool:
    """Return True when ``status`` is one of the FINISHED_* members."""
    finished_states = (
        SequenceStatus.FINISHED_STOPPED,
        SequenceStatus.FINISHED_LENGTH_CAPPED,
        SequenceStatus.FINISHED_ABORTED,
    )
    return status in finished_states
|
||||
|
||||
@staticmethod
|
||||
@@ -26,10 +28,13 @@ class SequenceStatus(enum.Enum):
|
||||
finish_reason = "stop"
|
||||
elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
|
||||
finish_reason = "length"
|
||||
elif status == SequenceStatus.FINISHED_ABORTED:
|
||||
finish_reason = "abort"
|
||||
else:
|
||||
finish_reason = None
|
||||
return finish_reason
|
||||
|
||||
|
||||
class SequenceData:
|
||||
|
||||
def __init__(
|
||||
@@ -137,6 +142,9 @@ class Sequence:
|
||||
def get_cumulative_logprob(self) -> float:
    """Return the ``cumulative_logprob`` value held by ``self.data``."""
    seq_data = self.data
    return seq_data.cumulative_logprob
|
||||
|
||||
def is_finished(self) -> bool:
    """Report whether ``self.status`` counts as finished per SequenceStatus."""
    current_status = self.status
    return SequenceStatus.is_finished(current_status)
|
||||
|
||||
def fork(self, child_seq: 'Sequence') -> None:
|
||||
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
|
||||
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
|
||||
@@ -182,7 +190,7 @@ class SequenceGroup:
|
||||
raise ValueError(f'Sequence {seq_id} not found.')
|
||||
|
||||
def is_finished(self) -> bool:
    """Return True when every sequence in ``self.seqs`` reports finished.

    An empty ``self.seqs`` yields True, because ``all`` over no items is True.
    """
    # Ask each sequence through its own is_finished() method. A single
    # return replaces the two stacked returns, the second of which could
    # never execute.
    return all(seq.is_finished() for seq in self.seqs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceGroup(request_id={self.request_id}, "
|
||||
|
||||
Reference in New Issue
Block a user