[Spec Decode] Reduce TP communication for speculative decoding draft token generation (#34049)

Signed-off-by: qizixi <qizixi@meta.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
Author: qizixi
Date: 2026-02-22 14:59:16 -08:00
Committed by: GitHub
Parent: b7892a3bef
Commit: 2bcf71b9c0
4 changed files with 114 additions and 6 deletions

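For a sense of scale, here is a back-of-envelope comparison of the bytes moved per greedy draft step under the two paths (a sketch with illustrative batch/vocab/TP sizes, not figures measured in this PR):

# Rough communication-volume estimate; all sizes here are assumptions.
batch, vocab, tp_size = 32, 128_256, 8
bytes_bf16, bytes_fp32 = 2, 4

# Old path: gather the full bf16 logits row from every rank.
full_logits = batch * vocab * bytes_bf16        # ~8.2 MB per step
# New path: all-gather one float32 (value, index) pair per rank.
pair_gather = batch * 2 * tp_size * bytes_fp32  # 2 KB per step
print(full_logits, pair_gather)                 # 8208384 2048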
@@ -5,6 +5,7 @@
import torch
from vllm.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_gather,
)
@@ -102,6 +103,58 @@ class LogitsProcessor(CustomOp):
            logits = logits[..., : self.org_vocab_size]
        return logits

    def get_top_tokens(
        self,
        lm_head: VocabParallelEmbedding,
        hidden_states: torch.Tensor,
        embedding_bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Vocab-parallel argmax without all-gathering the full logits.

        Each TP rank computes its local argmax, then only the (value, index)
        pairs are gathered and reduced. Communication: O(batch * 2 * tp_size)
        vs O(batch * vocab_size).
        """
        if self.scale <= 0.0:
            raise ValueError(
                "The local argmax reduction optimization is not supported "
                "for non-positive logit scaling factors."
            )
        tp_size = get_tensor_model_parallel_world_size()
        logits = lm_head.quant_method.apply(
            lm_head, hidden_states, bias=embedding_bias
        )
        if self.soft_cap is not None:
            logits = torch.tanh(logits / self.soft_cap) * self.soft_cap
        if self.scale != 1.0:
            logits = logits * self.scale
        # Mask out padding entries beyond org_vocab_size on this shard.
        num_pad = lm_head.shard_indices.num_org_vocab_padding
        if num_pad > 0:
            logits[..., -num_pad:] = -float("inf")
        local_max_vals, local_max_indices = logits.max(dim=-1)
        # Convert shard-local indices to global vocab indices.
        vocab_start = lm_head.shard_indices.org_vocab_start_index
        global_indices = local_max_indices + vocab_start
        if tp_size == 1:
            return global_indices
        # All-gather the (value, index) pairs, then reduce to the global argmax.
        # Use float32 to avoid bf16 precision loss on large vocab indices.
        local_pair = torch.stack(
            [local_max_vals.float(), global_indices.float()], dim=-1
        )
        # [batch, 2] -> [batch, 2 * tp_size]
        gathered = tensor_model_parallel_all_gather(local_pair, dim=-1)
        # [batch, tp_size, 2] where [:, :, 0] = values, [:, :, 1] = indices
        gathered = gathered.view(hidden_states.shape[0], tp_size, 2)
        max_rank_idx = gathered[:, :, 0].argmax(dim=-1, keepdim=True)
        top_tokens = gathered[:, :, 1].gather(dim=-1, index=max_rank_idx)
        return top_tokens.squeeze(-1).to(torch.int64)

    def extra_repr(self) -> str:
        s = f"vocab_size={self.vocab_size}"
        s += f", org_vocab_size={self.org_vocab_size}"

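The reduction is easy to sanity-check on a single device. Below is a minimal sketch that emulates the sharded argmax with plain torch ops (the shard layout and the all-gather are simulated; shapes are illustrative, and get_top_tokens above is the real tensor-parallel version):

import torch

torch.manual_seed(0)
batch, vocab, tp_size = 4, 1024, 8
logits = torch.randn(batch, vocab)
shards = logits.chunk(tp_size, dim=-1)  # each rank's vocab slice

pairs = []
for rank, shard in enumerate(shards):
    vals, idx = shard.max(dim=-1)
    # Shard-local indices -> global vocab ids, as in get_top_tokens.
    global_idx = idx + rank * (vocab // tp_size)
    pairs.append(torch.stack([vals.float(), global_idx.float()], dim=-1))

# Emulate tensor_model_parallel_all_gather along the last dim.
gathered = torch.cat(pairs, dim=-1).view(batch, tp_size, 2)
best_rank = gathered[:, :, 0].argmax(dim=-1, keepdim=True)
top = gathered[:, :, 1].gather(dim=-1, index=best_rank).squeeze(-1).long()

assert torch.equal(top, logits.argmax(dim=-1))  # matches the full argmax

The float32 cast on the index column is load-bearing: bf16 carries only 8 significant bits, so token ids beyond a few hundred would not survive a bf16 round-trip, whereas float32 is exact up to 2^24, far above any realistic vocab size.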
@@ -208,6 +208,23 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(input_ids, positions, hidden_states, inputs_embeds)

    def get_top_tokens(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Vocab-parallel argmax without all-gathering the full logits.

        Falls back to full logits when draft_id_to_target_id remapping is
        active, since the shared lm_head covers the full target vocab but
        the draft model only predicts over a subset (draft_vocab_size).
        """
        if (
            hasattr(self, "draft_id_to_target_id")
            and self.draft_id_to_target_id is not None
        ):
            return self.compute_logits(hidden_states).argmax(dim=-1)
        return self.logits_processor.get_top_tokens(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
        def transform(inputs):
            name, loaded_weight = inputs
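To see why the remapping path has to fall back, consider a toy example (all names and sizes below are invented for illustration): when the draft head scores only a sub-vocabulary whose ids are then translated into target-vocab ids, an argmax taken directly over the shared full-vocab lm_head could select a token the draft vocabulary cannot emit.

import torch

# Toy illustration of draft -> target id remapping (sizes are made up).
draft_logits = torch.tensor([0.1, 2.0, 0.3, 0.2])  # over a 4-token draft vocab
draft_id_to_target_id = torch.tensor([7, 2, 9, 0])  # draft id -> target id
# Correct path: argmax in draft space first, then remap to target space.
target_token = draft_id_to_target_id[draft_logits.argmax()]
print(target_token)  # tensor(2)
# A shard-local argmax over the full 10-token target vocab would instead
# rank ids the draft model never scores, so the shortcut is skipped there.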