[Model] MLPSpeculator speculative decoding support (#4947)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
This commit is contained in:
Joshua Rosenkranz
2024-06-20 20:23:12 -04:00
committed by GitHub
parent 6c5b7af152
commit b12518d3cf
18 changed files with 523 additions and 40 deletions

View File

@@ -794,6 +794,9 @@ class SamplerOutput:
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None
def __getitem__(self, idx: int):
    """Return the entry of ``self.outputs`` found at ``idx``."""
    entries = self.outputs
    return entries[idx]
@@ -842,6 +845,46 @@ class PoolerOutput:
self.__class__) and self.outputs == other.outputs
def get_all_seq_ids(
        seq_group_metadata_list: List["SequenceGroupMetadata"]) -> List[int]:
    """Flatten the sequence ids of every sequence group into one list.

    Each group's ``seq_data`` is iterated directly, so the ids are whatever
    that iteration yields (the keys, when ``seq_data`` is a mapping). Ids are
    emitted in group order, then in ``seq_data`` iteration order; duplicates
    are not removed.
    """
    all_ids: List[int] = []
    for group in seq_group_metadata_list:
        all_ids.extend(group.seq_data)
    return all_ids
class HiddenStates:
    """Hidden states of in-progress sequences, tracked by sequence id.

    Used in speculative decoding to hand the target model's hidden states
    to the proposer model in the subsequent step. ``seq_ids[i]`` is the
    sequence id of entry ``i`` along the batch dimension of the
    ``hidden_states`` tensor.
    """

    def __init__(self, seq_group_metadata_list: List["SequenceGroupMetadata"],
                 hidden_states: torch.Tensor):
        """Store ``hidden_states`` along with the ids of its sequences.

        NOTE(review): the assertion compares the number of sequence groups
        with the tensor length, while ``seq_ids`` holds every id of every
        group; the two agree only when each group contributes a single
        sequence -- confirm against callers.
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        initial_ids = get_all_seq_ids(seq_group_metadata_list)
        self.seq_ids: List[int] = initial_ids
        self.hidden_states: torch.Tensor = hidden_states

    def update(self, seq_group_metadata_list: List["SequenceGroupMetadata"],
               hidden_states: torch.Tensor) -> None:
        """Append the hidden states of a further target model invocation."""
        assert len(seq_group_metadata_list) == len(hidden_states)
        new_ids = get_all_seq_ids(seq_group_metadata_list)
        self.seq_ids.extend(new_ids)
        # Concatenate in the same order the ids were appended above.
        merged = torch.cat([self.hidden_states, hidden_states])
        self.hidden_states = merged

    def prune(self,
              seq_group_metadata_list: List["SequenceGroupMetadata"]) -> None:
        """Keep only the entries belonging to the given sequence groups.

        The stored tensor is re-indexed to follow the order of the provided
        ids. Nothing is touched when the ids already match exactly.
        """
        wanted_ids = get_all_seq_ids(seq_group_metadata_list)
        if wanted_ids == self.seq_ids:
            # Batch contents unchanged - nothing to prune.
            return

        # Batch contents changed - look up where each surviving sequence
        # currently sits (list.index raises ValueError for an unknown id).
        rows = []
        for seq_id in wanted_ids:
            rows.append(self.seq_ids.index(seq_id))
        self.hidden_states = self.hidden_states[rows]
        self.seq_ids = wanted_ids
@dataclass
class ExecuteModelRequest:
"""The model execution request."""
@@ -857,6 +900,8 @@ class ExecuteModelRequest:
num_lookahead_slots: int = 0
# The number of requests in the running queue.
running_queue_size: int = 0
# Optional hidden states from prior step.
previous_hidden_states: Optional[HiddenStates] = None
def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -869,4 +914,5 @@ class ExecuteModelRequest:
blocks_to_copy=self.blocks_to_copy.copy(),
num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states,
)