[Spec Decode] Reduce TP communication for speculative decoding draft token generation (#34049)
Signed-off-by: qizixi <qizixi@meta.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
@@ -109,6 +109,11 @@ class SpeculativeConfig:
|
||||
speculative input batches can contain sequences of different lengths,
|
||||
which may only be supported by certain attention backends. This currently
|
||||
only affects the EAGLE method of speculation."""
|
||||
use_local_argmax_reduction: bool = False
|
||||
"""Use vocab-parallel local argmax instead of all-gathering full logits
|
||||
for draft token generation. Reduces communication from O(vocab_size) to
|
||||
O(2 * tp_size) per token. Only applies to greedy draft selection in
|
||||
non-tree speculation."""
|
||||
|
||||
# Ngram proposer configuration
|
||||
prompt_lookup_max: int | None = Field(default=None, ge=1)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import torch
|
||||
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_gather,
|
||||
)
|
||||
@@ -102,6 +103,58 @@ class LogitsProcessor(CustomOp):
|
||||
logits = logits[..., : self.org_vocab_size]
|
||||
return logits
|
||||
|
||||
def get_top_tokens(
    self,
    lm_head: VocabParallelEmbedding,
    hidden_states: torch.Tensor,
    embedding_bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Vocab-parallel argmax without all-gathering full logits.

    Each TP rank computes logits for its own vocab shard and takes the
    local argmax; only the (value, global index) pairs are all-gathered
    and reduced. Communication: O(batch * 2 * tp_size) vs
    O(batch * vocab_size).

    Args:
        lm_head: Vocab-parallel output head. Its ``quant_method.apply`` is
            used to produce the shard-local logits and its
            ``shard_indices`` supply the padding count and the shard's
            starting vocab index.
        hidden_states: Hidden states with the hidden dimension last. Any
            number of leading dimensions is accepted.
        embedding_bias: Optional bias forwarded to ``quant_method.apply``.

    Returns:
        Integer tensor of global vocab indices, one per leading position
        of ``hidden_states``.

    Raises:
        ValueError: If ``self.scale`` is not strictly positive.
    """
    if self.scale <= 0.0:
        raise ValueError(
            "The local argmax reduction optimization is not supported for "
            "non-positive logit scaling factors."
        )
    tp_size = get_tensor_model_parallel_world_size()

    # Same post-processing order as the full-logits path: soft cap first,
    # then scaling.
    logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)
    if self.soft_cap is not None:
        logits = torch.tanh(logits / self.soft_cap) * self.soft_cap
    if self.scale != 1.0:
        logits = logits * self.scale

    # Mask out padding entries beyond org_vocab_size on this shard so they
    # can never win the argmax.
    # NOTE(review): assumes the padding occupies the trailing columns of
    # the shard-local logits - confirm for heads with added vocab.
    num_pad = lm_head.shard_indices.num_org_vocab_padding
    if num_pad > 0:
        logits[..., -num_pad:] = float("-inf")

    local_max_vals, local_max_indices = logits.max(dim=-1)

    # Convert shard-local indices to global vocab indices.
    vocab_start = lm_head.shard_indices.org_vocab_start_index
    global_indices = local_max_indices + vocab_start

    if tp_size == 1:
        # Single shard: the local winner is already the global winner.
        return global_indices

    # All-gather (value, index) pairs, then reduce to global argmax.
    # Use float32 to avoid bf16 precision loss on large vocab indices.
    local_pair = torch.stack(
        [local_max_vals.float(), global_indices.float()], dim=-1
    )
    # [..., 2] -> [..., 2 * tp_size]
    gathered = tensor_model_parallel_all_gather(local_pair, dim=-1)
    # [..., tp_size, 2] where [..., 0]=values, [..., 1]=indices. The leading
    # shape is taken from the local result so inputs with any number of
    # leading dimensions reshape correctly.
    gathered = gathered.view(*global_indices.shape, tp_size, 2)
    max_rank_idx = gathered[..., 0].argmax(dim=-1, keepdim=True)
    top_tokens = gathered[..., 1].gather(dim=-1, index=max_rank_idx)
    return top_tokens.squeeze(-1).to(torch.int64)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"vocab_size={self.vocab_size}"
|
||||
s += f", org_vocab_size={self.org_vocab_size}"
|
||||
|
||||
@@ -208,6 +208,23 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.model(input_ids, positions, hidden_states, inputs_embeds)
|
||||
|
||||
def get_top_tokens(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Pick the greedy token per position via vocab-parallel argmax.

    When a ``draft_id_to_target_id`` remapping is present, the full-logits
    path (``compute_logits`` followed by ``argmax``) is used instead,
    since the shared lm_head covers the full target vocab but the draft
    model only predicts over a subset (draft_vocab_size).
    """
    # A missing attribute and an attribute set to None both mean "no remap".
    remap_active = getattr(self, "draft_id_to_target_id", None) is not None
    if not remap_active:
        return self.logits_processor.get_top_tokens(self.lm_head, hidden_states)

    full_logits = self.compute_logits(hidden_states)
    return full_logits.argmax(dim=-1)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
|
||||
def transform(inputs):
|
||||
name, loaded_weight = inputs
|
||||
|
||||
@@ -99,6 +99,9 @@ class SpecDecodeBaseProposer:
|
||||
self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
|
||||
if self.parallel_drafting:
|
||||
self._init_parallel_drafting_params()
|
||||
self.use_local_argmax_reduction: bool = (
|
||||
self.speculative_config.use_local_argmax_reduction
|
||||
)
|
||||
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
@@ -369,6 +372,12 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
|
||||
|
||||
def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Return greedy draft token ids for the given hidden states.

    Uses the model's ``get_top_tokens`` when local argmax reduction is
    enabled; otherwise computes full logits and takes the argmax over the
    last dimension.
    """
    if not self.use_local_argmax_reduction:
        full_logits = self.model.compute_logits(hidden_states)
        return full_logits.argmax(dim=-1)
    return self.model.get_top_tokens(hidden_states)
|
||||
|
||||
def propose(
|
||||
self,
|
||||
# [num_tokens]
|
||||
@@ -491,11 +500,10 @@ class SpecDecodeBaseProposer:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
|
||||
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.num_speculative_tokens == 1 or self.parallel_drafting:
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids = self._greedy_sample(sample_hidden_states)
|
||||
return draft_token_ids.view(-1, self.num_speculative_tokens)
|
||||
|
||||
if self.uses_mrope:
|
||||
@@ -513,7 +521,8 @@ class SpecDecodeBaseProposer:
|
||||
hidden_states = hidden_states[token_indices_to_sample]
|
||||
|
||||
if isinstance(attn_metadata, TreeAttentionMetadata):
|
||||
# Draft using tree attention.
|
||||
# Draft using tree attention - requires full logits for top-k
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
draft_token_ids_list = self.propose_tree(
|
||||
batch_size=batch_size,
|
||||
logits=logits,
|
||||
@@ -525,7 +534,7 @@ class SpecDecodeBaseProposer:
|
||||
# [batch_size, num_tree_tokens]
|
||||
return torch.cat(draft_token_ids_list, dim=1)
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids = self._greedy_sample(sample_hidden_states)
|
||||
|
||||
if self.allowed_attn_types is not None and not isinstance(
|
||||
attn_metadata, self.allowed_attn_types
|
||||
@@ -690,8 +699,7 @@ class SpecDecodeBaseProposer:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size])
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
# [batch_size, num_speculative_tokens]
|
||||
@@ -1521,6 +1529,31 @@ class SpecDecodeBaseProposer:
|
||||
"Shared target model lm_head with MTP shared_head.head."
|
||||
)
|
||||
|
||||
if self.use_local_argmax_reduction:
|
||||
if not hasattr(self.model, "get_top_tokens"):
|
||||
raise ValueError(
|
||||
"use_local_argmax_reduction is enabled but draft model "
|
||||
f"{self.model.__class__.__name__} does not implement "
|
||||
"get_top_tokens()."
|
||||
)
|
||||
# Warn if draft model has vocab remapping, which forces fallback
|
||||
# to the full-logits path (negating the optimization).
|
||||
if (
|
||||
hasattr(self.model, "draft_id_to_target_id")
|
||||
and self.model.draft_id_to_target_id is not None
|
||||
):
|
||||
logger.warning(
|
||||
"use_local_argmax_reduction is enabled but draft model "
|
||||
"uses draft_id_to_target_id vocab remapping. The "
|
||||
"optimization will be bypassed (falling back to full "
|
||||
"logits gather + argmax)."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Using local argmax reduction for draft token generation "
|
||||
"(communication: O(2*tp_size) vs O(vocab_size))."
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user