diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 847e846d4..207d8c2f6 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -109,6 +109,11 @@ class SpeculativeConfig:
     speculative input batches can contain sequences of different lengths,
     which may only be supported by certain attention backends. This
     currently only affects the EAGLE method of speculation."""
+    use_local_argmax_reduction: bool = False
+    """Use a vocab-parallel local argmax instead of all-gathering the full
+    logits for draft token generation. Reduces communication from
+    O(vocab_size) to O(2 * tp_size) per token. Only applies to greedy draft
+    selection in non-tree speculation."""

     # Ngram proposer configuration
     prompt_lookup_max: int | None = Field(default=None, ge=1)
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 38753b0fc..dd2a61bc6 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -5,6 +5,7 @@
 import torch

 from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_gather,
 )
@@ -102,6 +103,58 @@ class LogitsProcessor(CustomOp):
             logits = logits[..., : self.org_vocab_size]
         return logits

+    def get_top_tokens(
+        self,
+        lm_head: VocabParallelEmbedding,
+        hidden_states: torch.Tensor,
+        embedding_bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Vocab-parallel argmax without all-gathering the full logits.
+
+        Each TP rank computes a local argmax; only the (value, index) pairs
+        are gathered and reduced. Communication: O(batch * 2 * tp_size) vs
+        O(batch * vocab_size).
+        """
+        if self.scale <= 0.0:
+            raise ValueError(
+                "The local argmax reduction optimization is not supported for "
+                "non-positive logit scaling factors."
+            )
+        tp_size = get_tensor_model_parallel_world_size()
+
+        logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)
+        if self.soft_cap is not None:
+            logits = torch.tanh(logits / self.soft_cap) * self.soft_cap
+        if self.scale != 1.0:
+            logits = logits * self.scale
+
+        # Mask out padding entries beyond org_vocab_size on this shard.
+        num_pad = lm_head.shard_indices.num_org_vocab_padding
+        if num_pad > 0:
+            logits[..., -num_pad:] = -float("inf")
+
+        local_max_vals, local_max_indices = logits.max(dim=-1)
+
+        # Convert shard-local indices to global vocab indices.
+        vocab_start = lm_head.shard_indices.org_vocab_start_index
+        global_indices = local_max_indices + vocab_start
+
+        if tp_size == 1:
+            return global_indices
+
+        # All-gather the (value, index) pairs, then reduce to the global argmax.
+        # Use float32 to avoid bf16 precision loss on large vocab indices.
+        local_pair = torch.stack(
+            [local_max_vals.float(), global_indices.float()], dim=-1
+        )
+        # [batch, 2] -> [batch, 2 * tp_size]
+        gathered = tensor_model_parallel_all_gather(local_pair, dim=-1)
+        # [batch, tp_size, 2], where [:, :, 0] = values and [:, :, 1] = indices.
+        gathered = gathered.view(hidden_states.shape[0], tp_size, 2)
+        max_rank_idx = gathered[:, :, 0].argmax(dim=-1, keepdim=True)
+        top_tokens = gathered[:, :, 1].gather(dim=-1, index=max_rank_idx)
+        return top_tokens.squeeze(-1).to(torch.int64)
+
     def extra_repr(self) -> str:
         s = f"vocab_size={self.vocab_size}"
         s += f", org_vocab_size={self.org_vocab_size}"
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index 02f5b5ff6..6c7b53d4d 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -208,6 +208,23 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.model(input_ids, positions, hidden_states, inputs_embeds)

+    def get_top_tokens(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Vocab-parallel argmax without all-gathering the full logits.
+
+        Falls back to full logits when draft_id_to_target_id remapping is
+        active, since the shared lm_head covers the full target vocab but
+        the draft model only predicts over a subset (draft_vocab_size).
+        """
+        if (
+            hasattr(self, "draft_id_to_target_id")
+            and self.draft_id_to_target_id is not None
+        ):
+            return self.compute_logits(hidden_states).argmax(dim=-1)
+        return self.logits_processor.get_top_tokens(self.lm_head, hidden_states)
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
         def transform(inputs):
             name, loaded_weight = inputs
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 04450e989..a46ba8f90 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -99,6 +99,9 @@ class SpecDecodeBaseProposer:
         self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
         if self.parallel_drafting:
             self._init_parallel_drafting_params()
+        self.use_local_argmax_reduction: bool = (
+            self.speculative_config.use_local_argmax_reduction
+        )

         max_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
@@ -369,6 +372,12 @@

         self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

+    def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Greedy-sample draft tokens from hidden states."""
+        if self.use_local_argmax_reduction:
+            return self.model.get_top_tokens(hidden_states)
+        return self.model.compute_logits(hidden_states).argmax(dim=-1)
+
     def propose(
         self,
         # [num_tokens]
@@ -491,11 +500,10 @@
             last_hidden_states, hidden_states = ret_hidden_states

         sample_hidden_states = last_hidden_states[token_indices_to_sample]
-        logits = self.model.compute_logits(sample_hidden_states)

         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1 or self.parallel_drafting:
-            draft_token_ids = logits.argmax(dim=-1)
+            draft_token_ids = self._greedy_sample(sample_hidden_states)
             return draft_token_ids.view(-1, self.num_speculative_tokens)

         if self.uses_mrope:
@@ -513,7 +521,8 @@
             hidden_states = hidden_states[token_indices_to_sample]

         if isinstance(attn_metadata, TreeAttentionMetadata):
-            # Draft using tree attention.
+            # Draft using tree attention, which requires full logits for top-k.
+            logits = self.model.compute_logits(sample_hidden_states)
             draft_token_ids_list = self.propose_tree(
                 batch_size=batch_size,
                 logits=logits,
@@ -525,7 +534,7 @@
             # [batch_size, num_tree_tokens]
             return torch.cat(draft_token_ids_list, dim=1)

-        draft_token_ids = logits.argmax(dim=-1)
+        draft_token_ids = self._greedy_sample(sample_hidden_states)
         if self.allowed_attn_types is not None and not isinstance(
             attn_metadata, self.allowed_attn_types
         ):
@@ -690,8 +699,7 @@
             last_hidden_states, hidden_states = ret_hidden_states
         hidden_states = hidden_states[:batch_size]

-        logits = self.model.compute_logits(last_hidden_states[:batch_size])
-        draft_token_ids = logits.argmax(dim=-1)
+        draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
         draft_token_ids_list.append(draft_token_ids)

         # [batch_size, num_speculative_tokens]
@@ -1521,6 +1529,31 @@
                     "Shared target model lm_head with MTP shared_head.head."
                 )

+        if self.use_local_argmax_reduction:
+            if not hasattr(self.model, "get_top_tokens"):
+                raise ValueError(
+                    "use_local_argmax_reduction is enabled but draft model "
+                    f"{self.model.__class__.__name__} does not implement "
+                    "get_top_tokens()."
+                )
+            # Warn if the draft model uses vocab remapping, which forces a
+            # fallback to the full-logits path (negating the optimization).
+            if (
+                hasattr(self.model, "draft_id_to_target_id")
+                and self.model.draft_id_to_target_id is not None
+            ):
+                logger.warning(
+                    "use_local_argmax_reduction is enabled but the draft "
+                    "model uses draft_id_to_target_id vocab remapping. The "
+                    "optimization will be bypassed (falling back to full "
+                    "logits gather + argmax)."
+                )
+            else:
+                logger.info(
+                    "Using local argmax reduction for draft token generation "
+                    "(communication: O(2 * tp_size) vs O(vocab_size))."
+                )
+
     @torch.inference_mode()
     def dummy_run(
         self,
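
A note on the reduction itself: the patch replaces an argmax over a
vocab-sharded logits matrix with a reduction over per-rank (value, global
index) pairs. The single-process sketch below simulates the TP shards by
chunking the vocab dimension and checks the result against a plain argmax;
sharded_argmax and the shapes used are illustrative only, not part of the
patch.

    import torch

    def sharded_argmax(logits: torch.Tensor, tp_size: int) -> torch.Tensor:
        """Argmax over [batch, vocab] logits via per-shard (value, index) pairs."""
        shard_size = logits.shape[-1] // tp_size  # assumes vocab divides evenly
        pairs = []
        for rank, shard in enumerate(logits.chunk(tp_size, dim=-1)):
            local_vals, local_idx = shard.max(dim=-1)
            # Each rank contributes one (value, global_index) pair per row
            # instead of its full vocab shard.
            global_idx = (local_idx + rank * shard_size).float()
            pairs.append(torch.stack([local_vals, global_idx], dim=-1))
        # Stand-in for tensor_model_parallel_all_gather(local_pair, dim=-1).
        gathered = torch.cat(pairs, dim=-1).view(logits.shape[0], tp_size, 2)
        best_rank = gathered[:, :, 0].argmax(dim=-1, keepdim=True)
        return gathered[:, :, 1].gather(dim=-1, index=best_rank).squeeze(-1).long()

    logits = torch.randn(4, 128256)  # [batch, vocab]; divides evenly by 8
    assert torch.equal(sharded_argmax(logits, tp_size=8), logits.argmax(dim=-1))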
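
Assuming the existing speculative_config dict plumbing in the offline
entrypoint, enabling the new flag might look like the following sketch; the
model name, draft checkpoint path, and token count are placeholders:

    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # placeholder target
        tensor_parallel_size=8,
        speculative_config={
            "method": "eagle",
            "model": "path/to/eagle-draft",  # placeholder draft checkpoint
            "num_speculative_tokens": 3,
            "use_local_argmax_reduction": True,  # flag added by this patch
        },
    )

With tensor_parallel_size=8, each greedy draft step then exchanges
2 * 8 = 16 scalars per sequence instead of a full vocab-sized logits row.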