diff --git a/tests/models/registry.py b/tests/models/registry.py index 64a0794b8..d139f707f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -1200,6 +1200,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { }, is_available_online=False, ), + "NemotronHMTPModel": _HfExamplesInfo( + "nvidia/Nemotron-Super-Placeholder", + speculative_model="nvidia/Nemotron-Super-Placeholder", + is_available_online=False, + ), } _TRANSFORMERS_BACKEND_MODELS = { diff --git a/tests/v1/attention/test_mamba_update_block_table.py b/tests/v1/attention/test_mamba_update_block_table.py index f60e690d5..923939053 100644 --- a/tests/v1/attention/test_mamba_update_block_table.py +++ b/tests/v1/attention/test_mamba_update_block_table.py @@ -41,6 +41,9 @@ def _make_vllm_config(block_size, max_model_len, max_num_seqs): cudagraph_mode=CUDAGraphMode.FULL, max_cudagraph_capture_size=None, ), + speculative_config=None, + num_speculative_tokens=0, + parallel_config=SimpleNamespace(decode_context_parallel_size=1), scheduler_config=SimpleNamespace(max_num_seqs=max_num_seqs), model_config=SimpleNamespace(max_model_len=max_model_len), ) @@ -92,7 +95,10 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers(): has_initial_states_p=None, query_start_loc_p=None, num_computed_tokens_p=None, - state_indices_tensor=builder_a.state_indices_tensor[:num_reqs], + state_indices_tensor_p=None, + query_start_loc_d=None, + num_accepted_tokens=None, + state_indices_tensor_d=builder_a.state_indices_tensor_d[:num_reqs], block_idx_last_scheduled_token=( builder_a.block_idx_last_scheduled_token[:num_reqs] ), diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 29f0380d1..c2bced784 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -36,6 +36,7 @@ MTPModelTypes = Literal[ "glm4_moe_lite_mtp", "glm_ocr_mtp", "ernie_mtp", + "nemotron_h_mtp", "exaone_moe_mtp", "qwen3_next_mtp", "qwen3_5_mtp", @@ -255,6 +256,19 @@ class SpeculativeConfig: {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]} ) + if ( + hf_config.model_type == "nemotron_h" + and hasattr(hf_config, "num_nextn_predict_layers") + and hf_config.num_nextn_predict_layers > 0 + ): + # Check if this is an MTP variant + hf_config.model_type = "nemotron_h_mtp" + if hf_config.model_type == "nemotron_h_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) + hf_config.update( + {"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]} + ) + if hf_config.model_type == "qwen3_next": hf_config.model_type = "qwen3_next_mtp" if hf_config.model_type == "qwen3_next_mtp": @@ -325,7 +339,7 @@ class SpeculativeConfig: if self.target_model_config is None: raise ValueError("target_model_config must be present for mtp") if self.target_model_config.hf_text_config.model_type == "deepseek_v32": - # FIXME(luccafong): cudgraph with v32 MTP is not supported, + # FIXME(luccafong): cudagraph with v32 MTP is not supported, # remove this when the issue is fixed. self.enforce_eager = True # use the draft model from the same model: @@ -427,7 +441,7 @@ class SpeculativeConfig: self.method = "mtp" if self.num_speculative_tokens > 1: logger.warning( - "Enabling num_speculative_tokens > 1 will run" + "Enabling num_speculative_tokens > 1 will run " "multiple times of forward on same MTP layer" ",which may result in lower acceptance rate" ) @@ -712,6 +726,7 @@ class SpeculativeConfig: "hunyuan_vl", "hunyuan_v1_dense", "afmoe", + "nemotron_h", ] if ( self.method == "eagle3" diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index a9930c490..2a0c0679f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -395,6 +395,15 @@ class VllmConfig: ] return hash_str + @property + def num_speculative_tokens(self) -> int: + if ( + self.speculative_config is not None + and self.speculative_config.num_speculative_tokens is not None + ): + return self.speculative_config.num_speculative_tokens + return 0 + @property def needs_dp_coordinator(self) -> bool: """ diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 347ce139e..3c6b01394 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -41,14 +41,6 @@ class MambaBase(AttentionLayerBase): pass def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: - if ( - vllm_config.speculative_config is not None - and vllm_config.model_config.hf_config.model_type - not in ["qwen3_next", "qwen3_5", "qwen3_5_moe"] - ): - raise NotImplementedError( - "Mamba with speculative decoding is not supported yet." - ) mamba_block_size = vllm_config.cache_config.mamba_block_size page_size_padded = vllm_config.cache_config.mamba_page_size_padded return MambaSpec( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index e2575a2b4..24e189a5c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -265,7 +265,8 @@ class MambaMixer(MambaBase, PluggableLayer): attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, Mamba1AttentionMetadata) query_start_loc_p = attn_metadata.query_start_loc_p - state_indices_tensor = attn_metadata.state_indices_tensor + state_indices_tensor_p = attn_metadata.state_indices_tensor_p + state_indices_tensor_d = attn_metadata.state_indices_tensor_d self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] @@ -295,17 +296,13 @@ class MambaMixer(MambaBase, PluggableLayer): prefill_decode_split = split_batch_to_prefill_and_decode( hidden_states_BC, gate, - state_indices_tensor, num_prefill_tokens, - num_prefills, num_decode_tokens, ) hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d gate_p = prefill_decode_split.gate_p gate_d = prefill_decode_split.gate_d - state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p - state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d if is_mamba_cache_all: block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( @@ -477,16 +474,12 @@ class PrefillDecodeSplit(NamedTuple): hidden_states_BC_d: torch.Tensor gate_p: torch.Tensor gate_d: torch.Tensor - state_indices_tensor_p: torch.Tensor - state_indices_tensor_d: torch.Tensor def split_batch_to_prefill_and_decode( hidden_states_BC: torch.Tensor, gate: torch.Tensor, - state_indices_tensor: torch.Tensor, num_prefill_tokens: int, - num_prefills: int, num_decode_tokens: int, ) -> PrefillDecodeSplit: num_actual_tokens = num_prefill_tokens + num_decode_tokens @@ -501,20 +494,11 @@ def split_batch_to_prefill_and_decode( gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1 ) - # num_decode_tokens accounts for CUDA graph padding when applicable - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor[: num_decode_tokens + num_prefills], - [num_decode_tokens, num_prefills], - dim=0, - ) - return PrefillDecodeSplit( hidden_states_BC_p=hidden_states_BC_p, hidden_states_BC_d=hidden_states_BC_d, gate_p=gate_p, gate_d=gate_d, - state_indices_tensor_p=state_indices_tensor_p, - state_indices_tensor_d=state_indices_tensor_d, ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 775c60c86..971581d89 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -477,7 +477,8 @@ class MambaMixer2(MambaBase, PluggableLayer): dim=-1, ) - compilation_config = get_current_vllm_config().compilation_config + vllm_config = get_current_vllm_config() + compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self @@ -488,6 +489,8 @@ class MambaMixer2(MambaBase, PluggableLayer): self.cache_config = cache_config self.prefix = prefix + self.num_spec = vllm_config.num_speculative_tokens + # Pre-compute sizes for forward pass self.tped_intermediate_size = self.intermediate_size // self.tp_size self.tped_conv_size = self.conv_dim // self.tp_size @@ -576,7 +579,6 @@ class MambaMixer2(MambaBase, PluggableLayer): # conv_state = (..., dim, width-1) yet contiguous along 'dim' conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] - state_indices_tensor = attn_metadata.state_indices_tensor has_initial_states_p = attn_metadata.has_initial_states_p prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size @@ -584,6 +586,12 @@ class MambaMixer2(MambaBase, PluggableLayer): query_start_loc_p = attn_metadata.query_start_loc_p cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p + state_indices_tensor_p = attn_metadata.state_indices_tensor_p + state_indices_tensor_d = attn_metadata.state_indices_tensor_d + num_accepted_tokens = attn_metadata.num_accepted_tokens + query_start_loc_d = attn_metadata.query_start_loc_d + num_decodes = attn_metadata.num_decodes + num_decode_tokens = attn_metadata.num_decode_tokens if attn_metadata is None: # profile run @@ -593,29 +601,21 @@ class MambaMixer2(MambaBase, PluggableLayer): hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C) return hidden_states - num_prefills = attn_metadata.num_prefills # request count - num_decodes = attn_metadata.num_decode_tokens # token count (=request) - num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + num_prefills = attn_metadata.num_prefills + num_prefill_tokens = attn_metadata.num_prefill_tokens has_prefill = num_prefills > 0 has_decode = num_decodes > 0 - num_actual_tokens = num_prefill_tokens + num_decodes + num_actual_tokens = num_prefill_tokens + num_decode_tokens - # Separate prefill and decode by splitting varlen input # Split along token dimension hidden_states_B_C_d, hidden_states_B_C_p = torch.split( hidden_states_B_C[:num_actual_tokens], - [num_decodes, num_prefill_tokens], + [num_decode_tokens, num_prefill_tokens], dim=0, ) dt_d, dt_p = torch.split( dt[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0, - ) - # Split along batch dimension - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor[:num_actual_tokens], - [num_decodes, num_prefills], + [num_decode_tokens, num_prefill_tokens], dim=0, ) @@ -642,16 +642,16 @@ class MambaMixer2(MambaBase, PluggableLayer): ) num_computed_tokens_p = attn_metadata.num_computed_tokens_p else: - block_idx_last_computed_token_d = None block_idx_last_computed_token_p = None - block_idx_last_scheduled_token_d = None block_idx_last_scheduled_token_p = None block_idx_first_scheduled_token_p = None + block_idx_last_scheduled_token_d = None + block_idx_last_computed_token_d = None num_computed_tokens_p = None preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( output[:num_actual_tokens], - [num_decodes, num_prefill_tokens], + [num_decode_tokens, num_prefill_tokens], dim=0, ) @@ -709,6 +709,7 @@ class MambaMixer2(MambaBase, PluggableLayer): ) # NOTE: final output is an in-place update of out tensor + assert preallocated_ssm_out_p is not None varlen_states = mamba_chunk_scan_combined_varlen( hidden_states_p.view( num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim @@ -840,6 +841,9 @@ class MambaMixer2(MambaBase, PluggableLayer): conv_state_indices=state_indices_tensor_d, block_idx_last_scheduled_token=block_idx_last_scheduled_token_d, initial_state_idx=block_idx_last_computed_token_d, + num_accepted_tokens=num_accepted_tokens, + query_start_loc=query_start_loc_d, + max_query_len=state_indices_tensor_d.size(-1), ) hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn( @@ -862,6 +866,7 @@ class MambaMixer2(MambaBase, PluggableLayer): -1, self.num_heads // self.tp_size, self.head_dim ) + assert preallocated_ssm_out_d is not None # - the hidden is reshaped into (bs, num_heads, head_dim) # - mamba_cache_params.ssm_state's slots will be selected # using state_indices_tensor_d @@ -879,7 +884,9 @@ class MambaMixer2(MambaBase, PluggableLayer): dt_softplus=True, state_batch_indices=state_indices_tensor_d_input, dst_state_batch_indices=state_indices_tensor_d_output, - out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), + out=preallocated_ssm_out_d.view(num_decode_tokens, -1, self.head_dim), + num_accepted_tokens=num_accepted_tokens, + cu_seqlens=query_start_loc_d, is_blackwell=self.is_blackwell, ) @@ -901,6 +908,7 @@ class MambaMixer2(MambaBase, PluggableLayer): head_dim=self.head_dim, state_size=self.ssm_state_size, conv_kernel=self.conv_kernel_size, + num_spec=self.num_spec, ) @property diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index d66dee7c9..fc8912f8c 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -133,6 +133,7 @@ class MambaStateShapeCalculator: head_dim: int, state_size: int, conv_kernel: int, + num_spec: int = 0, ) -> tuple[tuple[int, int], tuple[int, int, int]]: # if n_groups is not divisible by world_size, need to extend the shards # to ensure all groups needed by a head is sharded along with it @@ -141,7 +142,7 @@ class MambaStateShapeCalculator: conv_dim = intermediate_size + 2 * n_groups * state_size # contiguous along 'dim' axis - conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) + conv_state_shape = (conv_kernel - 1 + num_spec, divide(conv_dim, tp_world_size)) # These are not TP-ed as they depend on A, dt_bias, D # - they are typically small diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 157f9f346..b0c1ffb0d 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -1155,7 +1155,9 @@ def causal_conv1d_update( if conv_state_indices is None: assert conv_state.size(0) >= batch else: - assert (batch,) == conv_state_indices.shape + assert batch == conv_state_indices.shape[0], ( + f"ERROR: conv_state_indices should have shape ({batch},*) but got {conv_state_indices.shape}" + ) assert num_cache_lines >= batch assert weight.stride(1) == 1 # Need this diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index 14e00bce2..2348af2d9 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -119,7 +119,8 @@ class ShortConv(MambaBase, CustomOp): assert isinstance(attn_metadata, ShortConvAttentionMetadata) self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) - state_indices_tensor = attn_metadata.state_indices_tensor + state_indices_tensor_p = attn_metadata.state_indices_tensor_p + state_indices_tensor_d = attn_metadata.state_indices_tensor_d has_initial_states_p = attn_metadata.has_initial_states_p query_start_loc_p = attn_metadata.query_start_loc_p @@ -163,13 +164,6 @@ class ShortConv(MambaBase, CustomOp): [num_decodes, num_prefill_tokens], dim=0, ) - # Split along batch dimension - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor, - [num_decodes, num_prefills], - dim=0, - ) - conv_output_list = [] if has_prefill: diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index f1c34abf2..deb20852a 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -228,6 +228,7 @@ class Mamba2ForCausalLM( head_dim=hf_config.head_dim, state_size=hf_config.state_size, conv_kernel=hf_config.conv_kernel, + num_spec=vllm_config.num_speculative_tokens, ) @classmethod diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 06141013c..f180e4acd 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -636,6 +636,9 @@ class NemotronHModel(nn.Module): hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states + def is_spec_layer(self, config: NemotronHConfig, weight_name: str) -> bool: + return weight_name.startswith("mtp.") + def _get_max_n_routed_experts(self) -> int: """Get max n_routed_experts from config or block_configs for puzzle models. @@ -702,6 +705,10 @@ class NemotronHModel(nn.Module): if name is None: continue + # Skip MTP/spec decode layers early (before stacked params mapping) + if name.startswith("mtp."): + continue + # load stacked params for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -845,6 +852,7 @@ class NemotronHForCausalLM( head_dim=hf_config.mamba_head_dim, state_size=hf_config.ssm_state_size, conv_kernel=hf_config.conv_kernel, + num_spec=vllm_config.num_speculative_tokens, ) @classmethod diff --git a/vllm/model_executor/models/nemotron_h_mtp.py b/vllm/model_executor/models/nemotron_h_mtp.py new file mode 100644 index 000000000..b994e2b0d --- /dev/null +++ b/vllm/model_executor/models/nemotron_h_mtp.py @@ -0,0 +1,503 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""NemotronH-MTP model with attention layers.""" + +import typing +from collections.abc import Callable, Iterable + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config.parallel import ParallelConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import ( + make_empty_intermediate_tensors_factory, + maybe_prefix, +) +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import NemotronHConfig + +from .interfaces import SupportsPP +from .nemotron_h import ( + NemotronHAttentionDecoderLayer, + NemotronHMoEDecoderLayer, +) + + +class NemotronHMTPAttentionDecoderLayer(NemotronHAttentionDecoderLayer): + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, + prefix: str = "", + has_start_projections: bool = False, + has_end_norm: bool = False, + ) -> None: + super().__init__( + config=config, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + parallel_config=parallel_config, + prefix=prefix, + ) + self.has_start_projections = has_start_projections + self.has_end_norm = has_end_norm + + if has_start_projections: + self.enorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.hnorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + # Fusion layer to combine embeddings with target hidden states + self.eh_proj = ColumnParallelLinear( + input_size=config.hidden_size * 2, + output_size=config.hidden_size, + bias=False, + gather_output=True, + params_dtype=config.dtype + if hasattr(config, "dtype") + else torch.bfloat16, + quant_config=quant_config, + prefix=f"{prefix}.eh_proj", + ) + + if has_end_norm: + self.final_layernorm = RMSNorm( + config.hidden_size, + eps=getattr(config, "layer_norm_epsilon", 1e-5), + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # Start projections (Fusion) + if self.has_start_projections: + # Normalize both inputs before fusion + assert inputs_embeds is not None + inputs_embeds_normed = self.enorm(inputs_embeds) + previous_hidden_states_normed = self.hnorm(hidden_states) + + # Fuse via concatenation and linear projection + fused = torch.cat( + [inputs_embeds_normed, previous_hidden_states_normed], dim=-1 + ) + hidden_states, _ = self.eh_proj(fused) + + # Call parent forward (Attention) + # Parent forward expects: hidden_states, residual + hidden_states, residual = super().forward( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + # End norm + if self.has_end_norm: + if residual is not None: + hidden_states = hidden_states + residual + residual = None # Consumed residual + + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, residual + + +class NemotronHMTPMoEDecoderLayer(NemotronHMoEDecoderLayer): + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, + prefix: str = "", + has_start_projections: bool = False, + has_end_norm: bool = False, + ) -> None: + super().__init__( + config=config, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + parallel_config=parallel_config, + prefix=prefix, + ) + self.has_start_projections = has_start_projections + self.has_end_norm = has_end_norm + + if has_start_projections: + self.enorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.hnorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + # Fusion layer to combine embeddings with target hidden states + self.eh_proj = ColumnParallelLinear( + input_size=config.hidden_size * 2, + output_size=config.hidden_size, + bias=False, + gather_output=True, + params_dtype=config.dtype + if hasattr(config, "dtype") + else torch.bfloat16, + quant_config=quant_config, + prefix=f"{prefix}.eh_proj", + ) + + if has_end_norm: + self.final_layernorm = RMSNorm( + config.hidden_size, + eps=getattr(config, "layer_norm_epsilon", 1e-5), + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # Start projections (Fusion) + if self.has_start_projections: + # Normalize both inputs before fusion + assert inputs_embeds is not None + inputs_embeds_normed = self.enorm(inputs_embeds) + previous_hidden_states_normed = self.hnorm(hidden_states) + + # Fuse via concatenation and linear projection + fused = torch.cat( + [inputs_embeds_normed, previous_hidden_states_normed], dim=-1 + ) + hidden_states, _ = self.eh_proj(fused) + + # Call parent forward (MoE) + hidden_states, residual = super().forward( + hidden_states=hidden_states, + residual=residual, + ) + + # End norm + if self.has_end_norm: + if residual is not None: + hidden_states = hidden_states + residual + residual = None # Consumed residual + + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class NemotronHMultiTokenPredictor(nn.Module): + """MTP predictor with NemotronH layers.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + + self.config = config + self.vocab_size = config.vocab_size + self.org_vocab_size = config.vocab_size + + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1) + assert self.num_mtp_layers == 1, ( + "Only one MTP layer is supported for NemotronH-MTP" + ) + + self.pattern_str = config.mtp_hybrid_override_pattern + self.pattern_len = len(self.pattern_str) + assert self.pattern_len > 0 + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + + # Build flat list of layers + self.layers = torch.nn.ModuleDict() + + # Total number of physical layers = num_steps * pattern_len + total_layers = self.num_mtp_layers * self.pattern_len + for i in range(total_layers): + step_rel_idx = i % self.pattern_len + + char = self.pattern_str[step_rel_idx] + + is_start_of_step = step_rel_idx == 0 + is_end_of_step = step_rel_idx == self.pattern_len - 1 + + layer_prefix = f"{prefix}.layers.{i}" + + # TODO smor- remove double layers formation + common_kwargs = dict( + config=config, + layer_idx=self.mtp_start_layer_idx + i, + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + parallel_config=vllm_config.parallel_config, + prefix=layer_prefix, + has_start_projections=is_start_of_step, + has_end_norm=is_end_of_step, + ) + + if char == "*": + self.layers[str(i)] = NemotronHMTPAttentionDecoderLayer(**common_kwargs) + elif char == "E": + self.layers[str(i)] = NemotronHMTPMoEDecoderLayer(**common_kwargs) + else: + raise NotImplementedError( + f"Pattern char '{char}' in {self.pattern_str} not implemented" + ) + + self.make_empty_intermediate_tensors: Callable[..., IntermediateTensors] = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + assert self.embed_tokens is not None, ( + "embed_tokens not initialized - must be shared from target model" + ) + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + + residual = None + + for i in range(self.pattern_len): + hidden_states, residual = self.layers[str(i)]( + inputs_embeds=inputs_embeds, + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + return hidden_states + + +class NemotronHMTP(nn.Module, SupportsPP): + """NemotronH MTP model.""" + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.config = config + self.quant_config = vllm_config.quant_config + + # Needed for load_weights mapping + self.mtp_start_layer_idx = config.num_hidden_layers + + # EPLB config for experts + self.num_redundant_experts = 0 + if vllm_config.parallel_config and vllm_config.parallel_config.eplb_config: + self.num_redundant_experts = ( + vllm_config.parallel_config.eplb_config.num_redundant_experts + ) + + # MTP predictor + self.model = NemotronHMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp") + ) + + # LM head for generating logits + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + + self.logits_processor = LogitsProcessor(self.config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor: + """Forward - applies attention-based MTP.""" + hidden_states = self.model( + input_ids, + positions, + hidden_states, + intermediate_tensors, + inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + """Compute logits for DRAFT token generation.""" + assert self.lm_head is not None, ( + "lm_head not initialized - must be shared from target model" + ) + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load MTP weights with proper name remapping.""" + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = [] + if hasattr(self.config, "n_routed_experts") and self.config.n_routed_experts: + expert_params_mapping = FusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="up_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="", # Empty - non-gated MoE + num_experts=self.config.n_routed_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + # Only process MTP weights - skip all non-MTP weights + if ( + not name.startswith("mtp.") + and "embeddings" not in name + and "lm_head" not in name + ): + continue + # Skip rotary embeddings (computed, not loaded) + if "rotary_emb.inv_freq" in name: + continue + + name = name.replace("mtp.layers.", "model.layers.") + + if "embeddings" in name: + name = name.replace("embeddings", "embed_tokens") + if name.startswith("backbone."): + name = name.replace("backbone.", "model.") + + # Handle stacked parameters (qkv_proj) for attention layers + is_stacked = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + # Must be in a mixer (attention layer) + if ".mixer." not in name: + continue + + is_stacked = True + stacked_name = name.replace(weight_name, param_name) + + if stacked_name.endswith(".bias") and stacked_name not in params_dict: + continue + + if stacked_name not in params_dict: + # Might be that mapping failed or param doesn't exist + continue + + param = params_dict[stacked_name] + weight_loader = getattr(param, "weight_loader", None) + if weight_loader is not None: + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(stacked_name) + break + + if is_stacked: + continue + + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + # weight_name is like "experts.0.up_proj." + if weight_name not in name: + continue + + is_expert_weight = True + + # Replace the expert-specific weight name with fused parameter name + # e.g., "experts.0.up_proj." -> "experts.w13_" + name_mapped = name.replace(weight_name, param_name) + + if name_mapped not in params_dict: + continue + + param = params_dict[name_mapped] + weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + loaded_params.add(name_mapped) + break + + if is_expert_weight: + continue + + if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index f8fff2ccb..81ba858d6 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -266,7 +266,8 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): # conv_state = (..., dim, width-1) yet contiguous along 'dim' conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] - state_indices_tensor = attn_metadata.state_indices_tensor + state_indices_tensor_p = attn_metadata.state_indices_tensor_p + state_indices_tensor_d = attn_metadata.state_indices_tensor_d has_initial_states_p = attn_metadata.has_initial_states_p prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size @@ -309,13 +310,6 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): gate_d, gate_p = torch.split( gate[:num_actual_tokens], [num_decodes, num_prefill_tokens], dim=0 ) - # Split along batch dimension - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor, - [num_decodes, num_prefills], - dim=0, - ) - # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs preallocated_ssm_out = torch.empty( @@ -336,7 +330,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): if has_prefill: # 2. Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions - # pointed to by "state_indices_tensor" + # pointed to by "state_indices_tensor_p" x = hidden_states_p.transpose(0, 1) # this is the form that causal-conv see hidden_states_p = causal_conv1d_fn( x, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 329411d62..7d9fc0226 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -522,6 +522,7 @@ _SPECULATIVE_DECODING_MODELS = { "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"), + "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"), "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"), diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py index 86c117fd9..ed62b5d29 100644 --- a/vllm/transformers_utils/configs/nemotron_h.py +++ b/vllm/transformers_utils/configs/nemotron_h.py @@ -51,6 +51,8 @@ class NemotronHConfig(PretrainedConfig): The pattern of the hybrid model. The pattern is a string of characters where each character represents M: Mamba2, *: Attention, -: MLP + mtp_hybrid_override_pattern (`str`, *optional*, defaults to `"*E"`): + The pattern of the MTP layers. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer encoder. @@ -150,6 +152,7 @@ class NemotronHConfig(PretrainedConfig): intermediate_size=21504, num_hidden_layers=52, hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-", + mtp_hybrid_override_pattern="*E", num_attention_heads=32, head_dim=128, num_key_value_heads=8, # nemo: num_query_groups @@ -203,6 +206,7 @@ class NemotronHConfig(PretrainedConfig): self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.hybrid_override_pattern = hybrid_override_pattern + self.mtp_hybrid_override_pattern = mtp_hybrid_override_pattern self.num_attention_heads = num_attention_heads self.head_dim = head_dim self.sliding_window = sliding_window @@ -215,10 +219,9 @@ class NemotronHConfig(PretrainedConfig): assert len(self.hybrid_override_pattern) == self.num_hidden_layers, ( "hybrid_override_pattern must have same length as num_hidden_layers" ) - assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), ( - "hybrid_override_pattern must only contain characters 'M', '*', or '-'" + assert re.match(r"^[*-ME]+$", self.hybrid_override_pattern), ( + "hybrid_override_pattern must only contain characters 'M', '*', '-', or 'E'" ) - # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 08e543736..94587c3d6 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from dataclasses import dataclass, replace +from typing import Any import torch @@ -200,8 +201,11 @@ class Mamba2AttentionMetadataBuilder( common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, + **kwargs: Any, ) -> Mamba2AttentionMetadata: - common = self._compute_common_metadata(common_attn_metadata) + common = self._compute_common_metadata( + common_attn_metadata, num_accepted_tokens=kwargs.get("num_accepted_tokens") + ) seq_idx_p = None cu_chunk_seqlen_p = None diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 286a34f99..c4ffb16f5 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -2,9 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import abc -import copy -from dataclasses import dataclass -from typing import ClassVar, TypeVar +from dataclasses import dataclass, replace +from typing import Any, ClassVar, TypeVar import torch @@ -35,12 +34,21 @@ class BaseMambaAttentionMetadata: num_reqs: int # The following tensors only contain prefill requests and will be None if - # the batch has no prefill request. + # the batch has no prefill requests. has_initial_states_p: torch.Tensor | None query_start_loc_p: torch.Tensor | None num_computed_tokens_p: torch.Tensor | None + state_indices_tensor_p: torch.Tensor | None - state_indices_tensor: torch.Tensor + # The following tensors are used for decode requests and + # speculative decoding compatibility, and will be None if the batch + # has no decode requests. + state_indices_tensor_d: torch.Tensor | None + query_start_loc_d: torch.Tensor | None # shape: [num_decodes + 1,] + + # Number of accepted tokens for each spec sequence (for loading correct checkpoint) + # Includes the bonus token (so minimum is 1) + num_accepted_tokens: torch.Tensor | None # shape: [batch,] # The following tensors are only used for prefix caching in all mode and # are None if disabled @@ -60,9 +68,9 @@ class BaseMambaAttentionMetadata: class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): metadata_cls: type[M] reorder_batch_threshold: int = 1 - _cudagraph_support: ClassVar[AttentionCGSupport] = ( - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - ) + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + + # Will be disabled if speculative decoding is used supports_update_block_table: bool = True def __init__( @@ -74,6 +82,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) + # Enable speculative decoding support + self.speculative_config = vllm_config.speculative_config + self.compilation_config = vllm_config.compilation_config + self.num_spec_tokens: int = vllm_config.num_speculative_tokens + self.use_spec_decode = self.num_spec_tokens > 0 + assert isinstance(kv_cache_spec, MambaSpec) self.compilation_config = vllm_config.compilation_config self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs @@ -84,13 +98,17 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ) if self.vllm_config.cache_config.mamba_cache_mode == "all": - self.state_indices_tensor = torch.empty( + max_num_blocks = cdiv( + self.vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size, + ) + # Speculative decoding not supported with prefix caching, + # so keep shape consistent with prefill buffer + # TODO: reduce this size as needed for decode-only cudagraph capture + self.state_indices_tensor_d = torch.empty( ( self.decode_cudagraph_max_bs, - cdiv( - self.vllm_config.model_config.max_model_len, - self.kv_cache_spec.block_size, - ), + max_num_blocks, ), dtype=torch.int32, device=device, @@ -106,12 +124,25 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): device=device, ) else: - self.state_indices_tensor = torch.empty( + self.state_indices_tensor_d = torch.empty( + (self.decode_cudagraph_max_bs, 1 + self.num_spec_tokens), + dtype=torch.int32, + device=device, + ) + + # For speculative decoding, we need to store the following buffers + # for CUDA graph capture during decode + if self.num_spec_tokens > 0: + self.decode_num_accepted_tokens = torch.empty( (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) + self._init_reorder_batch_threshold(1, self.use_spec_decode) + if self.use_spec_decode: + self.supports_update_block_table = False + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> M: @@ -121,26 +152,38 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): """ m = common_attn_metadata - assert m.num_reqs == m.num_actual_tokens, ( + assert ( + m.max_query_len <= 1 + self.num_spec_tokens + and m.num_reqs <= self.decode_cudagraph_max_bs + ), ( "Mamba only supports decode-only full CUDAGraph capture. " "Make sure all cudagraph capture sizes <= max_num_seq." ) - m.max_query_len = 1 # decode-only + assert m.max_query_len == 1 + self.num_spec_tokens # decode-only - return self.build(0, m) + num_accepted_tokens = None + if self.num_spec_tokens > 0: + num_accepted_tokens = torch.diff(m.query_start_loc) + + return self.build(0, m, num_accepted_tokens=num_accepted_tokens) def build( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, + *, + num_accepted_tokens: torch.Tensor | None = None, + **kwargs: Any, ) -> M: """ Default build implementation for Mamba-like attention backends. Subclasses (e.g., Mamba2) can override to add additional metadata. """ - return self._compute_common_metadata(common_attn_metadata) + return self._compute_common_metadata( + common_attn_metadata, num_accepted_tokens=num_accepted_tokens + ) def _compute_prefix_caching_block_indices( self, @@ -176,21 +219,32 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): def _compute_common_metadata( self, common_attn_metadata: CommonAttentionMetadata, + *, + num_accepted_tokens: torch.Tensor | None = None, ) -> M: """ Compute metadata common to both Mamba1 and Mamba2. """ num_reqs = common_attn_metadata.num_reqs + # Treat multi-token queries as decode requests when + # speculative decoding is enabled. Otherwise, use the + # default decode threshold to prevent misclassification + # of prefill queries as decode requests. + decode_threshold = ( + self.reorder_batch_threshold if num_accepted_tokens is not None else 1 + ) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, decode_threshold=self.reorder_batch_threshold + common_attn_metadata, decode_threshold=decode_threshold ) ) # Need flags to indicate if there are initial states has_initial_states_p = None query_start_loc_p = None + query_start_loc_d = None num_computed_tokens = None num_computed_tokens_p = None @@ -218,13 +272,31 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): common_attn_metadata, mamba_block_size ) else: - # Always return just a single block per each request: state_indices_tensor = mamba_get_block_table_tensor( common_attn_metadata.block_table_tensor, common_attn_metadata.seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, - )[:, 0] + ) + + if state_indices_tensor.dim() == 1: + state_indices_tensor = state_indices_tensor.unsqueeze(-1) + + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + if self.vllm_config.cache_config.mamba_cache_mode != "all": + state_indices_tensor_d = state_indices_tensor_d[ + :, : 1 + self.num_spec_tokens + ] + state_indices_tensor_p = state_indices_tensor_p[:, 0] + + if num_decodes > 0 and self.use_spec_decode: + assert num_accepted_tokens is not None + query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1] + num_accepted_tokens = num_accepted_tokens[:num_decodes] if num_prefills > 0: if num_computed_tokens is None: @@ -258,39 +330,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[ num_reqs - num_prefills : num_reqs ] - elif ( - num_decodes <= self.decode_cudagraph_max_bs - and self.compilation_config.cudagraph_mode.has_full_cudagraphs() - ): - self.state_indices_tensor[:num_decodes].copy_( - state_indices_tensor, non_blocking=True - ) - state_indices_tensor = self.state_indices_tensor[:num_decode_tokens] - state_indices_tensor[num_decodes:] = PAD_SLOT_ID - if self.vllm_config.cache_config.mamba_cache_mode == "all": - self.block_idx_last_scheduled_token[:num_decodes].copy_( - block_idx_last_scheduled_token, non_blocking=True - ) - block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ - :num_decode_tokens - ] - - self.block_idx_last_computed_token[:num_decodes].copy_( - block_idx_last_computed_token, non_blocking=True - ) - block_idx_last_computed_token = self.block_idx_last_computed_token[ - :num_decode_tokens - ] - - return self.metadata_cls( + metadata = self.metadata_cls( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, query_start_loc_p=query_start_loc_p, has_initial_states_p=has_initial_states_p, - state_indices_tensor=state_indices_tensor, + state_indices_tensor_p=state_indices_tensor_p, + state_indices_tensor_d=state_indices_tensor_d, + num_accepted_tokens=num_accepted_tokens, + query_start_loc_d=query_start_loc_d, block_idx_last_scheduled_token=block_idx_last_scheduled_token, block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, block_idx_last_computed_token=block_idx_last_computed_token, @@ -302,55 +353,112 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): token_chunk_offset_ptr=token_chunk_offset_ptr, ) + return self._update_metadata_for_cudagraph_capture(metadata) + + def _update_metadata_for_cudagraph_capture( + self, + metadata: M, + ) -> M: + """ + Update the metadata for cudagraph capture. + Currently, only decode is supported for full cudagraphs with Mamba. + """ + state_indices_tensor_d = metadata.state_indices_tensor_d + query_start_loc_d = metadata.query_start_loc_d + num_accepted_tokens = metadata.num_accepted_tokens + block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token + block_idx_last_computed_token = metadata.block_idx_last_computed_token + if ( + metadata.num_prefills == 0 + and metadata.num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ): + padded_bs = metadata.num_reqs + self.state_indices_tensor_d[: metadata.num_decodes].copy_( + state_indices_tensor_d, non_blocking=True + ) + state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs] + state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID + + if self.use_spec_decode: + assert query_start_loc_d is not None + assert num_accepted_tokens is not None + query_start_loc_d = query_start_loc_d[: padded_bs + 1] + self.decode_num_accepted_tokens[: metadata.num_decodes].copy_( + num_accepted_tokens, non_blocking=True + ) + num_accepted_tokens = self.decode_num_accepted_tokens[:padded_bs] + num_accepted_tokens[metadata.num_decodes :] = ( + 1 # pad with 1st slot index + ) + + if self.vllm_config.cache_config.mamba_cache_mode == "all": + assert block_idx_last_scheduled_token is not None + assert block_idx_last_computed_token is not None + self.block_idx_last_scheduled_token[: metadata.num_decodes].copy_( + block_idx_last_scheduled_token[: metadata.num_decodes], + non_blocking=True, + ) + block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ + : metadata.num_decode_tokens + ] + + self.block_idx_last_computed_token[: metadata.num_decodes].copy_( + block_idx_last_computed_token[: metadata.num_decodes], + non_blocking=True, + ) + block_idx_last_computed_token = self.block_idx_last_computed_token[ + : metadata.num_decode_tokens + ] + + return replace( + metadata, + state_indices_tensor_d=state_indices_tensor_d, + query_start_loc_d=query_start_loc_d, + num_accepted_tokens=num_accepted_tokens, + block_idx_last_scheduled_token=block_idx_last_scheduled_token, + block_idx_last_computed_token=block_idx_last_computed_token, + ) + def update_block_table( self, metadata: M, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ) -> M: - new_metadata = copy.copy(metadata) - state_indices_t = mamba_get_block_table_tensor( + state_indices_tensor = mamba_get_block_table_tensor( blk_table, metadata.seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, ) - if self.vllm_config.cache_config.mamba_cache_mode in ("none", "align"): - # Only needs the block that saves the running state - state_indices_t = state_indices_t[:, 0] + if state_indices_tensor.dim() == 1: + state_indices_tensor = state_indices_tensor.unsqueeze(-1) - num_reqs = blk_table.shape[0] + assert ( + metadata.num_prefills + metadata.num_decodes + == state_indices_tensor.shape[0] + ), ( + "Mismatch in number of requests when updating block table." + f" Expected {metadata.num_prefills + metadata.num_decodes}, " + f"got {state_indices_tensor.shape[0]}." + ) - # For CUDA graphs, copy to persistent buffer - if ( - metadata.num_prefills == 0 - and num_reqs <= self.decode_cudagraph_max_bs - and self.compilation_config.cudagraph_mode.has_full_cudagraphs() - ): - persistent_state_indices_t = self.state_indices_tensor[:num_reqs] - persistent_state_indices_t.copy_(state_indices_t, non_blocking=True) - state_indices_t = persistent_state_indices_t + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [metadata.num_decodes, metadata.num_prefills], + dim=0, + ) + if self.vllm_config.cache_config.mamba_cache_mode != "all": + state_indices_tensor_d = state_indices_tensor_d[ + :, : 1 + self.num_spec_tokens + ] + state_indices_tensor_p = state_indices_tensor_p[:, 0] - # For 'all' mode, also update prefix caching block indices - # to use this builder's persistent buffers (required for CUDA - # graph replay to read from the correct memory addresses). - if self.vllm_config.cache_config.mamba_cache_mode == "all": - assert metadata.block_idx_last_scheduled_token is not None - assert metadata.block_idx_last_computed_token is not None - self.block_idx_last_scheduled_token[:num_reqs].copy_( - metadata.block_idx_last_scheduled_token[:num_reqs], - non_blocking=True, - ) - new_metadata.block_idx_last_scheduled_token = ( - self.block_idx_last_scheduled_token[: metadata.num_decode_tokens] - ) - self.block_idx_last_computed_token[:num_reqs].copy_( - metadata.block_idx_last_computed_token[:num_reqs], - non_blocking=True, - ) - new_metadata.block_idx_last_computed_token = ( - self.block_idx_last_computed_token[: metadata.num_decode_tokens] - ) + new_metadata = replace( + metadata, + state_indices_tensor_d=state_indices_tensor_d, + state_indices_tensor_p=state_indices_tensor_p, + ) - new_metadata.state_indices_tensor = state_indices_t - return new_metadata + return self._update_metadata_for_cudagraph_capture(new_metadata) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0013ec3d7..99b799ea4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -113,6 +113,7 @@ from vllm.v1.attention.backend import ( MultipleOf, ) from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( create_fast_prefill_custom_backend, get_dcp_local_seq_lens, @@ -1852,7 +1853,9 @@ class GPUModelRunner( ) extra_attn_metadata_args = {} - if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): + if use_spec_decode and isinstance( + builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder) + ): assert ubid is None, "UBatching not supported with GDN yet" extra_attn_metadata_args = dict( num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded], @@ -4725,7 +4728,7 @@ class GPUModelRunner( # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. - assert num_tokens <= self.scheduler_config.max_num_batched_tokens + assert num_tokens <= self.max_num_tokens max_num_reqs = self.scheduler_config.max_num_seqs if create_mixed_batch: assert not uniform_decode @@ -4849,6 +4852,7 @@ class GPUModelRunner( ubatch_slices=(ubatch_slices_padded if pad_attn else ubatch_slices), for_cudagraph_capture=is_graph_capturing, slot_mappings=slot_mappings_by_group, + use_spec_decode=self.speculative_config is not None, ) with self.maybe_dummy_run_with_lora(