[Speculative Decoding] EAGLE Implementation with Top-1 proposer (#6830)
This commit is contained in:
@@ -1092,6 +1092,10 @@ class SamplerOutput(
|
||||
# Optional last hidden states from the model.
|
||||
hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
# Optional prefill hidden states from the model
|
||||
# (used for models like EAGLE).
|
||||
prefill_hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
# Time taken in the forward pass for this across all workers
|
||||
model_forward_time: Optional[float] = None
|
||||
|
||||
@@ -1176,40 +1180,87 @@ class HiddenStates(msgspec.Struct, array_like=True,
|
||||
omit_defaults=True): # type: ignore[call-arg]
|
||||
"""Hidden states corresponding to in-progress sequences.
|
||||
Used in speculative decoding to pass hidden states from
|
||||
the target model to the proposer model in the subsequent step.
|
||||
the target model to the proposer model.
|
||||
|
||||
seq_ids are the sequence ids of each entry of the batch
|
||||
dimension of the hidden_states tensor"""
|
||||
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
# Scorer hidden states. For prefill step, it is used for hidden states of
|
||||
# all tokens, whereas for decode step, it is used for last accepted tokens.
|
||||
hidden_states: torch.Tensor
|
||||
# The sequence group metadata list. Only needed for decode step.
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
||||
# Scorer hidden states of the 2nd last token proposed by the proposer (
|
||||
# irrespective of whether it was accepted or not). Only used for cases when
|
||||
# last proposed token is accepted (i.e., in case of bonus tokens). For the
|
||||
# case of no bonus tokens, these are ignored.
|
||||
second_last_token_hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
_seq_ids: List[int] = msgspec.field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
    """Validate the optional metadata list and cache its sequence ids.

    ``seq_group_metadata_list`` is optional (it is only needed for decode
    steps). When it is None nothing is validated and ``_seq_ids`` keeps
    its default value. Otherwise its length must equal
    ``len(self.hidden_states)`` and the sequence ids are extracted from it.
    """
    # Guard first: len() / get_all_seq_ids() must not run on None.
    if self.seq_group_metadata_list is not None:
        assert len(self.seq_group_metadata_list) == len(self.hidden_states)
        self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
|
||||
|
||||
# Read-only accessor for the tracked sequence ids. Returns the internal
# `_seq_ids` list itself (not a copy).
@property
|
||||
def seq_ids(self) -> List[int]:
|
||||
return self._seq_ids
|
||||
|
||||
def update(self,
           hidden_states: torch.Tensor,
           seq_group_metadata_list: List[SequenceGroupMetadata],
           second_last_token_hidden_states: Optional[torch.Tensor] = None):
    """Update hidden states from target model invocation. Only used for
    decode steps.

    Appends the sequence ids taken from ``seq_group_metadata_list`` to the
    tracked ids and concatenates ``hidden_states`` after the rows already
    stored.

    Args:
        hidden_states: Rows to append; ``len(hidden_states)`` must equal
            ``len(seq_group_metadata_list)``.
        seq_group_metadata_list: Metadata the new sequence ids are
            extracted from.
        second_last_token_hidden_states: Optional rows to append to the
            stored ``second_last_token_hidden_states``. Ignored when this
            object does not currently store any.
    """
    assert len(seq_group_metadata_list) == len(hidden_states)
    self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
    self.hidden_states = torch.cat([self.hidden_states, hidden_states])

    if self.second_last_token_hidden_states is not None:
        # Adding dummy hidden_states to this to maintain same shape
        self.second_last_token_hidden_states = torch.cat([
            self.second_last_token_hidden_states,
            torch.zeros_like(hidden_states)
            if second_last_token_hidden_states is None else
            second_last_token_hidden_states
        ])
|
||||
|
||||
def prune(self,
          seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
    """Prune to provided list of sequence ids. Only used for decode steps.

    When the ids extracted from ``seq_group_metadata_list`` differ from the
    tracked ones, ``hidden_states`` (and
    ``second_last_token_hidden_states`` when present) are re-indexed to
    follow the order of the provided ids, and the tracked ids are replaced
    by them. When they are equal nothing changes.

    Raises:
        ValueError: if a provided sequence id is not currently tracked.
    """
    # Currently this prunes all seq_ids not present in
    # seq_group_metadata_list which might cause problems where a sequence
    # may be "paused" then "resumed" later. This should only prune sequences
    # which are confirmed to be aborted.
    seq_ids = get_all_seq_ids(seq_group_metadata_list)
    if seq_ids == self._seq_ids:
        return

    # Batch contents changed - prune removed sequences.
    # Build the id -> row map once instead of a linear list.index() scan
    # per id; setdefault keeps the first occurrence, as list.index() does.
    row_of = {}
    for row, seq_id in enumerate(self._seq_ids):
        row_of.setdefault(seq_id, row)
    try:
        index = [row_of[seq_id] for seq_id in seq_ids]
    except KeyError as err:
        # Keep the exception type list.index() raised for unknown ids.
        raise ValueError(f"{err.args[0]!r} is not in list") from None

    self.hidden_states = self.hidden_states[index]
    if self.second_last_token_hidden_states is not None:
        self.second_last_token_hidden_states = \
            self.second_last_token_hidden_states[index]
    self._seq_ids = seq_ids
|
||||
|
||||
def expand_with_bonus_tokens(
        self, seq_with_bonus_token_in_last_step: set) -> None:
    """Expand hidden states for sequences with bonus tokens. This is in
    alignment with `MultiStepWorker._expand_execute_model_request`.

    For every tracked sequence id contained in
    ``seq_with_bonus_token_in_last_step``, the matching row of
    ``second_last_token_hidden_states`` is placed directly before that
    sequence's row of ``hidden_states``; all other sequences keep a single
    row. Nothing happens when ``second_last_token_hidden_states`` is None
    or the set is empty. The tracked sequence ids are not modified.
    """
    if self.second_last_token_hidden_states is None \
            or not seq_with_bonus_token_in_last_step:
        return

    # After concatenation, row i + num_seqs is the second-last-token state
    # of the sequence at position i (assumes hidden_states holds exactly
    # one row per tracked sequence id - TODO confirm against callers).
    num_seqs = len(self._seq_ids)
    index = []
    # enumerate yields each id's position directly (ids are expected to be
    # unique), avoiding a linear list.index() scan per sequence.
    for i, seq_id in enumerate(self._seq_ids):
        if seq_id in seq_with_bonus_token_in_last_step:
            index.append(i + num_seqs)
        index.append(i)

    self.hidden_states = torch.cat(
        [self.hidden_states, self.second_last_token_hidden_states])[index]
|
||||
|
||||
|
||||
class ExecuteModelRequest(
|
||||
msgspec.Struct,
|
||||
|
||||
Reference in New Issue
Block a user