Support beam search & parallel generation (#7)
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import copy
|
||||
import enum
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from cacheflow.block import LogicalTokenBlock
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
|
||||
|
||||
class SequenceStatus(enum.Enum):
|
||||
@@ -24,9 +26,11 @@ class Sequence:
|
||||
|
||||
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
||||
# Initialize the logical token blocks with the given token ids.
|
||||
self.append(token_ids)
|
||||
self.add(token_ids)
|
||||
|
||||
self.status = SequenceStatus.PENDING
|
||||
self.output_logprobs: List[Dict[int, float]] = []
|
||||
self.cumulative_logprobs = 1.0
|
||||
|
||||
def add_block(self) -> None:
|
||||
block = LogicalTokenBlock(
|
||||
@@ -35,7 +39,7 @@ class Sequence:
|
||||
)
|
||||
self.logical_token_blocks.append(block)
|
||||
|
||||
def append(self, token_ids: List[int]) -> None:
|
||||
def add(self, token_ids: List[int]) -> None:
|
||||
while token_ids:
|
||||
if not self.logical_token_blocks:
|
||||
self.add_block()
|
||||
@@ -49,6 +53,12 @@ class Sequence:
|
||||
last_block.append(token_ids[:num_empty_slots])
|
||||
token_ids = token_ids[num_empty_slots:]
|
||||
|
||||
def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
    """Record one newly generated token and its logprob information.

    Args:
        token_id: The token to store; it must be a key of ``logprobs``.
        logprobs: Mapping of token id to logprob for this step. The dict
            is appended to ``self.output_logprobs`` as-is (not copied).
    """
    # Internal invariant: the chosen token must have a logprob entry.
    assert token_id in logprobs
    # Store the token through the bulk `add` path as a one-element list.
    new_tokens = [token_id]
    self.add(new_tokens)
    self.output_logprobs.append(logprobs)
    # Fold the chosen token's logprob into the running total.
    self.cumulative_logprobs = self.cumulative_logprobs + logprobs[token_id]
|
||||
|
||||
def get_len(self) -> int:
    """Return the total number of tokens across all logical token blocks."""
    total = 0
    for token_block in self.logical_token_blocks:
        total += token_block.num_tokens
    return total
|
||||
|
||||
@@ -58,6 +68,14 @@ class Sequence:
|
||||
token_ids.extend(block.get_token_ids())
|
||||
return token_ids
|
||||
|
||||
def get_last_token_id(self) -> int:
    """Return the last token id reported by the final logical token block.

    Raises:
        IndexError: If the sequence has no logical token blocks.
    """
    tail_block = self.logical_token_blocks[-1]
    return tail_block.get_last_token_id()
|
||||
|
||||
def fork(self, child_seq: 'Sequence') -> 'Sequence':
    """Copy this sequence's generation state into ``child_seq``.

    The logical token blocks and the per-step logprob history are
    deep-copied so that later changes to either sequence do not affect
    the other. The cumulative logprob is copied by value.

    Args:
        child_seq: The sequence object to overwrite with this state.

    Returns:
        ``child_seq``, as the return annotation promises. The original
        body fell off the end and returned None.
    """
    child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
    child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
    child_seq.cumulative_logprobs = self.cumulative_logprobs
    return child_seq
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f'Sequence(seq_id={self.seq_id}, '
|
||||
f'status={self.status.name}, '
|
||||
@@ -74,11 +92,17 @@ class SequenceGroup:
|
||||
self.group_id = group_id
|
||||
self.seqs = seqs
|
||||
|
||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||
def get_seqs(
    self,
    status: Optional[SequenceStatus] = None,
) -> List[Sequence]:
    """Return the group's sequences, optionally filtered by status.

    Args:
        status: When given, only sequences whose ``status`` equals this
            value are returned. When None, no filtering is applied.

    Returns:
        With no filter, the group's own ``self.seqs`` list (not a copy).
        Otherwise a new list holding the matching sequences.
    """
    if status is not None:
        return [seq for seq in self.seqs if seq.status == status]
    return self.seqs
|
||||
|
||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
    """Count the group's sequences, optionally only those with ``status``.

    Delegates the selection to ``get_seqs`` and returns the length of
    its result.
    """
    selected = self.get_seqs(status)
    return len(selected)
|
||||
|
||||
def find(self, seq_id: int) -> Sequence:
|
||||
for seq in self.seqs:
|
||||
@@ -92,3 +116,45 @@ class SequenceGroup:
|
||||
def __repr__(self) -> str:
|
||||
return (f'SequenceGroup(group_id={self.group_id}, '
|
||||
f'num_seqs={len(self.seqs)})')
|
||||
|
||||
|
||||
class SequenceGroupInputs:
    """Plain container for the inputs of one sequence group.

    Every constructor argument is stored unchanged on an attribute of
    the same name; no validation or copying is performed.

    Attributes:
        group_id: Identifier of the sequence group.
        is_prompt: Flag stored as given.
        input_tokens: Seq id -> token ids.
        context_len: Context length stored as given.
        seq_logprobs: Seq id -> cumulative logprobs.
        sampling_params: Sampling parameters for the group.
        block_tables: Seq id -> list of physical block numbers.
    """

    def __init__(
        self,
        group_id: int,
        is_prompt: bool,
        input_tokens: Dict[int, List[int]],
        context_len: int,
        seq_logprobs: Dict[int, float],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
    ) -> None:
        # Group-level fields.
        self.group_id = group_id
        self.is_prompt = is_prompt
        self.context_len = context_len
        self.sampling_params = sampling_params

        # Per-sequence mappings, all keyed by seq id.
        self.input_tokens = input_tokens
        self.seq_logprobs = seq_logprobs
        self.block_tables = block_tables
|
||||
|
||||
|
||||
class SequenceOutputs:
    """Plain container for one sequence's output token and its logprobs.

    Attributes:
        seq_id: Id of the sequence this output belongs to.
        parent_seq_id: Id of the parent sequence (presumably the one this
            sequence was forked from; confirm against the caller).
        output_token: The output token id.
        logprobs: Token id -> logP(x_i+1 | x_0, ..., x_i).
    """

    def __init__(
        self,
        seq_id: int,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],  # Token id -> logP(x_i+1 | x_0, ..., x_i).
    ) -> None:
        self.seq_id = seq_id
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        """Return a debug string listing all four fields."""
        # The original closed the parenthesis right after `output_token`
        # and left `logprobs=` outside it; all fields now sit inside one
        # balanced pair of parentheses.
        return (f'SequenceOutputs(seq_id={self.seq_id}, '
                f'parent_seq_id={self.parent_seq_id}, '
                f'output_token={self.output_token}, '
                f'logprobs={self.logprobs})')
|
||||
|
||||
Reference in New Issue
Block a user