[Model] MLPSpeculator speculative decoding support (#4947)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com> Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
This commit is contained in:
committed by
GitHub
parent
6c5b7af152
commit
b12518d3cf
@@ -794,6 +794,9 @@ class SamplerOutput:
|
||||
# Spec decode metrics populated by workers.
|
||||
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||
|
||||
# Optional last hidden states from the model.
|
||||
hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self.outputs[idx]
|
||||
|
||||
@@ -842,6 +845,46 @@ class PoolerOutput:
|
||||
self.__class__) and self.outputs == other.outputs
|
||||
|
||||
|
||||
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence ids of every sequence group into one list.

    Groups are visited in the order given, and within each group the ids
    are taken in the iteration order of that group's ``seq_data``.
    """
    all_ids: List[int] = []
    for group in seq_group_metadata_list:
        # Iterating ``seq_data`` yields the sequence ids of the group.
        all_ids.extend(group.seq_data)
    return all_ids
|
||||
|
||||
|
||||
class HiddenStates:
    """Hidden states corresponding to in-progress sequences.

    Used in speculative decoding to pass hidden states from
    the target model to the proposer model in the subsequent step.

    ``seq_ids`` holds the sequence id of each entry along the batch
    dimension of the ``hidden_states`` tensor.
    """

    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
                 hidden_states: torch.Tensor):
        """Record the sequence ids together with their hidden states.

        Args:
            seq_group_metadata_list: Sequence groups the hidden states
                belong to; must have the same length as ``hidden_states``.
            hidden_states: Tensor whose first dimension is indexed by
                batch entry.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
        self.hidden_states: torch.Tensor = hidden_states

    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
               hidden_states: torch.Tensor) -> None:
        """Update hidden states from target model invocation.

        The new sequence ids and hidden states are appended after the
        entries that are already stored.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids.

        After the call, ``seq_ids`` equals the ids found in
        ``seq_group_metadata_list`` and ``hidden_states`` holds the
        matching rows in that same order.

        Raises:
            ValueError: If a requested sequence id has no stored entry.
        """
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids == self.seq_ids:
            # Batch contents unchanged - nothing to prune.
            return

        # Map each stored id to its row once, instead of scanning the list
        # with ``list.index`` for every kept sequence (which is quadratic).
        # Only the first occurrence is recorded, matching ``list.index``.
        row_of = {}
        for row, seq_id in enumerate(self.seq_ids):
            if seq_id not in row_of:
                row_of[seq_id] = row

        try:
            index = [row_of[seq_id] for seq_id in seq_ids]
        except KeyError as err:
            # Keep the exception type ``list.index`` raised for a missing id.
            raise ValueError(
                f"sequence id {err.args[0]!r} has no stored hidden state"
            ) from err

        # Batch contents changed - prune removed sequences.
        self.hidden_states = self.hidden_states[index]
        self.seq_ids = seq_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecuteModelRequest:
|
||||
"""The model execution request."""
|
||||
@@ -857,6 +900,8 @@ class ExecuteModelRequest:
|
||||
num_lookahead_slots: int = 0
|
||||
# The number of requests in the running queue.
|
||||
running_queue_size: int = 0
|
||||
# Optional hidden states from prior step.
|
||||
previous_hidden_states: Optional[HiddenStates] = None
|
||||
|
||||
def clone(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
@@ -869,4 +914,5 @@ class ExecuteModelRequest:
|
||||
blocks_to_copy=self.blocks_to_copy.copy(),
|
||||
num_lookahead_slots=self.num_lookahead_slots,
|
||||
running_queue_size=self.running_queue_size,
|
||||
previous_hidden_states=self.previous_hidden_states,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user