[Spec Decode] Reduce TP communication for speculative decoding draft token generation (#34049)
Signed-off-by: qizixi <qizixi@meta.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
@@ -208,6 +208,23 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.model(input_ids, positions, hidden_states, inputs_embeds)
|
||||
|
||||
def get_top_tokens(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Return the argmax token ids for ``hidden_states``.

    Two paths are possible:

    * No ``draft_id_to_target_id`` table (attribute missing or ``None``):
      delegate to ``self.logits_processor.get_top_tokens`` with the shared
      ``lm_head``, which performs a vocab-parallel argmax and so avoids
      all-gathering the full logits.
    * A ``draft_id_to_target_id`` table is present: compute the full logits
      through ``self.compute_logits`` and take ``argmax`` over the last
      dimension. The shortcut is not used here because the shared lm_head
      spans the whole target vocab while the draft model only predicts
      over a subset of it (draft_vocab_size).
    """
    # A single lookup covers both "attribute absent" and "attribute is None".
    remap_table = getattr(self, "draft_id_to_target_id", None)
    if remap_table is None:
        # Fast path: no draft->target remapping is active.
        return self.logits_processor.get_top_tokens(self.lm_head, hidden_states)

    # Remapping is active: fall back to the full-logits argmax.
    full_logits = self.compute_logits(hidden_states)
    return full_logits.argmax(dim=-1)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
|
||||
def transform(inputs):
|
||||
name, loaded_weight = inputs
|
||||
|
||||
Reference in New Issue
Block a user