# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Qwen3Next model.""" from collections.abc import Iterable from itertools import islice import torch from einops import rearrange from torch import nn from transformers.activations import ACT2FN from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CacheConfig, ModelConfig, SpeculativeConfig, VllmConfig, get_current_vllm_config, ) from vllm.distributed import ( divide, get_ep_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fla.ops import ( chunk_gated_delta_rule as fla_chunk_gated_delta_rule, ) from vllm.model_executor.layers.fla.ops import ( fused_recurrent_gated_delta_rule, ) from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import ( GemmaRMSNorm as Qwen3NextRMSNorm, ) from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateCopyFunc, MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, sharded_weight_loader, ) from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from .interfaces import ( HasInnerState, IsHybrid, MixtureOfExperts, SupportsLoRA, SupportsPP, ) from .utils import ( AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, ) logger = init_logger(__name__) KVCache = tuple[torch.Tensor, torch.Tensor] def fi_chunk_gated_delta_rule( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, initial_state: torch.Tensor, output_final_state: bool, cu_seqlens: torch.LongTensor | None = None, use_qk_l2norm_in_kernel: bool = True, ): from flashinfer.gdn_prefill import ( chunk_gated_delta_rule as chunk_gated_delta_rule_fi, ) if use_qk_l2norm_in_kernel: q = l2norm_fwd(q) k = l2norm_fwd(k) # use flashinfer implementation q = q.squeeze(0).contiguous() k = k.squeeze(0).contiguous() v = v.squeeze(0).contiguous() g = g.squeeze(0).contiguous() beta = beta.squeeze(0).contiguous() fi_state = initial_state.to(torch.float32) fi_g = g.to(torch.float32) fi_beta = beta.to(torch.float32) output, final_state = chunk_gated_delta_rule_fi( q=q, k=k, v=v, g=torch.exp(fi_g), beta=fi_beta, initial_state=fi_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, ) # Unsqueeze back to 4D (1, L, H, D) to match fla output format return output.unsqueeze(0), final_state @CustomOp.register("chunk_gated_delta_rule") class ChunkGatedDeltaRule(CustomOp): def __init__(self) -> None: super().__init__() if current_platform.is_cuda() and current_platform.is_device_capability(90): logger.info_once( "Using FlashInfer GDN prefill kernel on CUDA compute capability 90" ) self._forward_method = self.forward_cuda else: self._forward_method = self.forward_native def forward_cuda( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, initial_state: torch.Tensor, output_final_state: bool, cu_seqlens: torch.LongTensor | None = None, use_qk_l2norm_in_kernel: bool = True, ): return fi_chunk_gated_delta_rule( q=q, k=k, v=v, g=g, beta=beta, initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, ) def forward_native( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, initial_state: torch.Tensor, output_final_state: bool, cu_seqlens: torch.LongTensor | None = None, use_qk_l2norm_in_kernel: bool = True, ): return fla_chunk_gated_delta_rule( q=q, k=k, v=v, g=g, beta=beta, initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, ) class Qwen3NextSparseMoeBlock(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_text_config parallel_config = vllm_config.parallel_config quant_config = vllm_config.quant_config self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}." ) # Load balancing settings. vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = parallel_config.enable_eplb self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size self.physical_expert_start = self.ep_rank * self.n_local_physical_experts self.physical_expert_end = ( self.physical_expert_start + self.n_local_physical_experts ) self.gate = ReplicatedLinear( config.hidden_size, config.num_experts, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate", ) self.shared_expert_gate = ReplicatedLinear( config.hidden_size, 1, bias=False, quant_config=None, prefix=f"{prefix}.shared_expert_gate", ) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen3NextMLP( hidden_size=config.hidden_size, intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, expert_gate=self.shared_expert_gate, prefix=f"{prefix}.shared_expert", ) else: self.shared_expert = None self.experts = SharedFusedMoE( shared_experts=self.shared_expert, gate=self.gate, num_experts=self.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, renormalize=getattr(config, "norm_topk_prob", True), quant_config=quant_config, prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) if self.experts.is_internal_router: # In this case, the gate/router runs inside the FusedMoE class final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=hidden_states ) else: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) if self.shared_expert is not None: final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states, 0 ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states ) return final_hidden_states.view(orig_shape) class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): @property def mamba_type(self) -> str: return "gdn_attention" def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( self.model_config.dtype, self.cache_config.mamba_cache_dtype, self.cache_config.mamba_ssm_cache_dtype, ) def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.gated_delta_net_state_shape( self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim, self.head_v_dim, self.conv_kernel_size, self.num_spec, ) def __init__( self, config: Qwen3NextConfig, model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, speculative_config: SpeculativeConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() self.hidden_size = config.hidden_size self.num_v_heads = config.linear_num_value_heads self.num_k_heads = config.linear_num_key_heads self.head_k_dim = config.linear_key_head_dim self.head_v_dim = config.linear_value_head_dim self.key_dim = self.head_k_dim * self.num_k_heads self.value_dim = self.head_v_dim * self.num_v_heads self.conv_kernel_size = config.linear_conv_kernel_dim self.layer_idx = extract_layer_index(prefix) self.activation = config.hidden_act self.act = ACT2FN[config.hidden_act] self.layer_norm_epsilon = config.rms_norm_eps self.prefix = prefix self.config = config self.model_config = model_config self.cache_config = cache_config self.quant_config = quant_config self.speculative_config = speculative_config self.num_spec = ( self.speculative_config.num_speculative_tokens if self.speculative_config else 0 ) # QKV self.conv_dim = self.key_dim * 2 + self.value_dim self.conv1d = ColumnParallelLinear( input_size=self.conv_kernel_size, output_size=self.conv_dim, bias=False, prefix=f"{prefix}.conv1d", ) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) # projection of the input hidden states # Qwen3-Next and Qwen3.5 has a different qkv_proj layout, # we need to create qkvz_proj adaptively here. self.in_proj_qkvz = self.create_qkvz_proj( hidden_size=self.hidden_size, key_dim=self.key_dim, value_dim=self.value_dim, quant_config=quant_config, prefix=f"{prefix}.in_proj_qkvz", ) # ba_proj doesn't support blockwise fp8 quantization. self.in_proj_ba = MergedColumnParallelLinear( input_size=self.hidden_size, output_sizes=[self.num_v_heads] * 2, bias=False, quant_config=quant_config, prefix=f"{prefix}.in_proj_ba", ) query_key_settings = (self.key_dim, 0, False) value_settings = (self.value_dim, 0, False) delattr(self.conv1d.weight, "weight_loader") set_weight_attrs( self.conv1d.weight, { "weight_loader": mamba_v2_sharded_weight_loader( [ query_key_settings, query_key_settings, value_settings, ], self.tp_size, self.tp_rank, ) }, ) # selective projection used to make dt, B and C input dependant # time step projection (discretization) # instantiate once and copy inv_dt in init_weights of PretrainedModel self.dt_bias = nn.Parameter( torch.ones(self.num_v_heads // self.tp_size), ) self.A_log = nn.Parameter( torch.empty( divide(self.num_v_heads, self.tp_size), ) ) set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) self.norm = RMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, group_size=None, norm_before_gate=True, device=current_platform.current_device(), dtype=config.dtype, ) self.out_proj = RowParallelLinear( self.value_dim, self.hidden_size, bias=False, input_is_parallel=True, quant_config=quant_config, prefix=f"{prefix}.out_proj", ) self.chunk_gated_delta_rule = ChunkGatedDeltaRule() compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self def create_qkvz_proj( self, hidden_size: int, key_dim: int, value_dim: int, quant_config: QuantizationConfig | None, prefix: str, ) -> MergedColumnParallelLinear: return MergedColumnParallelLinear( input_size=hidden_size, output_sizes=[sum((key_dim, key_dim, value_dim, value_dim))], bias=False, quant_config=quant_config, prefix=prefix, ) def fix_query_key_value_ordering( self, mixed_qkvz: torch.Tensor, mixed_ba: torch.Tensor, ): """ Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. """ new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( self.num_k_heads // self.tp_size, ( self.head_k_dim + self.head_k_dim + (self.head_v_dim + self.head_v_dim) * self.num_v_heads // self.num_k_heads ), ) new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( self.num_k_heads // self.tp_size, 2 * self.num_v_heads // self.num_k_heads, ) mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) mixed_ba = mixed_ba.view(*new_tensor_shape_ba) split_arg_list_qkvz = [ self.head_k_dim, self.head_k_dim, (self.num_v_heads // self.num_k_heads * self.head_v_dim), (self.num_v_heads // self.num_k_heads * self.head_v_dim), ] split_arg_list_ba = [ self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads, ] # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] value = value.reshape(value.size(0), -1, self.head_v_dim) z = z.reshape(z.size(0), -1, self.head_v_dim) b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) return query, key, value, z, b, a def rearrange_mixed_qkv(self, mixed_qkv): if mixed_qkv is None: return None, None, None query, key, value = torch.split( mixed_qkv, [ self.key_dim // self.tp_size, self.key_dim // self.tp_size, self.value_dim // self.tp_size, ], dim=-1, ) query, key = map( lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), (query, key), ) value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) return query.contiguous(), key.contiguous(), value.contiguous() def forward( self, hidden_states: torch.Tensor, output: torch.Tensor, ): """ Forward pass with three parts: 1. Input projection 2. Core attention (custom op) 3. Output projection """ num_tokens = hidden_states.size(0) # ============================================================ # Part 1: Input Projection # ============================================================ projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) projected_states_ba, _ = self.in_proj_ba(hidden_states) query, key, value, z, b, a = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba ) query, key, value = map( lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value) ) mixed_qkv = torch.cat((query, key, value), dim=-1) # ============================================================ # Part 2: Core Attention (Custom Op) # ============================================================ # Note: we should not use torch.empty here like other attention backends, # see discussions in https://github.com/vllm-project/vllm/pull/28182 core_attn_out = torch.zeros( (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), dtype=hidden_states.dtype, device=hidden_states.device, ) torch.ops.vllm.gdn_attention_core( mixed_qkv, b, a, core_attn_out, self.prefix, ) # ============================================================ # Part 3: Output Projection # ============================================================ z_shape_og = z.shape # Reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") output[:num_tokens], _ = self.out_proj(core_attn_out) def _forward_core( self, mixed_qkv: torch.Tensor, b: torch.Tensor, a: torch.Tensor, core_attn_out: torch.Tensor, ): """ Core attention computation (called by custom op). """ forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata if attn_metadata is None: # V1 profile run return assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, GDNAttentionMetadata) has_initial_state = attn_metadata.has_initial_state spec_query_start_loc = attn_metadata.spec_query_start_loc non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc spec_sequence_masks = attn_metadata.spec_sequence_masks spec_token_indx = attn_metadata.spec_token_indx non_spec_token_indx = attn_metadata.non_spec_token_indx spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens mixed_qkv = mixed_qkv[:num_actual_tokens] b = b[:num_actual_tokens] a = a[:num_actual_tokens] # 1. Convolution sequence transformation conv_weights = self.conv1d.weight.view( self.conv1d.weight.size(0), self.conv1d.weight.size(2) ) if spec_sequence_masks is not None: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: mixed_qkv_spec = mixed_qkv mixed_qkv_non_spec = None else: mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) else: mixed_qkv_spec = None mixed_qkv_non_spec = mixed_qkv # 1.1: Process the multi-query part if spec_sequence_masks is not None: mixed_qkv_spec = causal_conv1d_update( mixed_qkv_spec, conv_state, conv_weights, self.conv1d.bias, self.activation, conv_state_indices=spec_state_indices_tensor[:, 0][ : attn_metadata.num_spec_decodes ], num_accepted_tokens=num_accepted_tokens, query_start_loc=spec_query_start_loc, max_query_len=spec_state_indices_tensor.size(-1), validate_data=False, ) # 1.2: Process the remaining part if attn_metadata.num_prefills > 0: mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) # - "cache_indices" updates the conv_state cache in positions # pointed to by "state_indices_tensor" mixed_qkv_non_spec = causal_conv1d_fn( mixed_qkv_non_spec_T, conv_weights, self.conv1d.bias, activation=self.activation, conv_states=conv_state, has_initial_state=has_initial_state, cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, metadata=attn_metadata, ).transpose(0, 1) elif attn_metadata.num_decodes > 0: mixed_qkv_non_spec = causal_conv1d_update( mixed_qkv_non_spec, conv_state, conv_weights, self.conv1d.bias, self.activation, conv_state_indices=non_spec_state_indices_tensor[ : attn_metadata.num_actual_tokens ], validate_data=True, ) else: mixed_qkv_non_spec = None query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( mixed_qkv_non_spec ) g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) if spec_sequence_masks is not None: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: g_spec = g beta_spec = beta g_non_spec = None beta_non_spec = None else: g_spec = g.index_select(1, spec_token_indx) beta_spec = beta.index_select(1, spec_token_indx) g_non_spec = g.index_select(1, non_spec_token_indx) beta_non_spec = beta.index_select(1, non_spec_token_indx) else: g_spec = None beta_spec = None g_non_spec = g beta_non_spec = beta # 2. Recurrent attention # 2.1: Process the multi-query part if spec_sequence_masks is not None: core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( q=query_spec, k=key_spec, v=value_spec, g=g_spec, beta=beta_spec, initial_state=ssm_state, inplace_final_state=True, cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], ssm_state_indices=spec_state_indices_tensor, num_accepted_tokens=num_accepted_tokens, use_qk_l2norm_in_kernel=True, ) else: core_attn_out_spec, last_recurrent_state = None, None # 2.2: Process the remaining part if attn_metadata.num_prefills > 0: initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 ( core_attn_out_non_spec, last_recurrent_state, ) = self.chunk_gated_delta_rule( q=query_non_spec, k=key_non_spec, v=value_non_spec, g=g_non_spec, beta=beta_non_spec, initial_state=initial_state, output_final_state=True, cu_seqlens=non_spec_query_start_loc, use_qk_l2norm_in_kernel=True, ) # Init cache ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( ssm_state.dtype ) elif attn_metadata.num_decodes > 0: core_attn_out_non_spec, last_recurrent_state = ( fused_recurrent_gated_delta_rule( q=query_non_spec, k=key_non_spec, v=value_non_spec, g=g_non_spec, beta=beta_non_spec, initial_state=ssm_state, inplace_final_state=True, cu_seqlens=non_spec_query_start_loc[ : attn_metadata.num_decodes + 1 ], ssm_state_indices=non_spec_state_indices_tensor, use_qk_l2norm_in_kernel=True, ) ) else: core_attn_out_non_spec, last_recurrent_state = None, None # 3. Merge core attention output if spec_sequence_masks is not None and core_attn_out_non_spec is not None: merged_out = torch.empty( (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), dtype=core_attn_out_non_spec.dtype, device=core_attn_out_non_spec.device, ) merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) elif spec_sequence_masks is not None: core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) else: core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) class Qwen3NextAttention(nn.Module): def __init__( self, config: Qwen3NextConfig, model_config: ModelConfig | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = config.num_key_value_heads if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. assert self.total_num_kv_heads % tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.dual_chunk_attention_config = getattr( config, "dual_chunk_attention_config", None ) self.attn_output_gate = getattr(config, "attn_output_gate", True) self.qkv_proj = QKVParallelLinear( config.hidden_size, self.head_dim, self.total_num_heads * (1 + self.attn_output_gate), self.total_num_kv_heads, bias=getattr(config, "qkv_bias", False), quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( head_size=self.head_dim, max_position=config.max_position_embeddings, rope_parameters=config.rope_parameters, dual_chunk_attention_config=self.dual_chunk_attention_config, ) self.attn = Attention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": self.dual_chunk_attention_config, } if self.dual_chunk_attention_config else {}, ) self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, positions: torch.Tensor, output: torch.Tensor, hidden_states: torch.Tensor, ): qkv, _ = self.qkv_proj(hidden_states) if self.attn_output_gate: q_gate, k, v = qkv.split( [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 ) orig_shape = q_gate.shape[:-1] q_gate = q_gate.view(*orig_shape, self.num_heads, -1) q, gate = torch.chunk(q_gate, 2, dim=-1) q = q.reshape(*orig_shape, -1) gate = gate.reshape(*orig_shape, -1) else: q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( -1, self.num_heads * self.head_dim ) k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( -1, self.num_kv_heads * self.head_dim ) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) if self.attn_output_gate: gate = torch.sigmoid(gate) attn_output = attn_output * gate output[:], _ = self.o_proj(attn_output) class Qwen3NextDecoderLayer(nn.Module): def __init__( self, vllm_config: VllmConfig, layer_type: str, prefix: str = "", ) -> None: super().__init__() config = vllm_config.model_config.hf_config model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config speculative_config = vllm_config.speculative_config self.layer_type = layer_type self.layer_idx = extract_layer_index(prefix) if self.layer_type == "linear_attention": self.linear_attn = Qwen3NextGatedDeltaNet( config, model_config=model_config, cache_config=cache_config, quant_config=quant_config, speculative_config=speculative_config, prefix=f"{prefix}.linear_attn", ) elif self.layer_type == "full_attention": self.self_attn = Qwen3NextAttention( config, model_config=model_config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", ) else: raise ValueError(f"Invalid layer_type {self.layer_type}") mlp_only_layers = ( [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers ) if (self.layer_idx not in mlp_only_layers) and ( config.num_experts > 0 and (self.layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3NextSparseMoeBlock( vllm_config=vllm_config, prefix=f"{prefix}.mlp", ) else: self.mlp = Qwen3NextMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.mlp", ) self.input_layernorm = Qwen3NextRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.post_attention_layernorm = Qwen3NextRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.layer_scale = getattr(config, "layer_scale", False) if self.layer_scale: self.attn_layer_scale = torch.nn.Parameter( torch.zeros( 1, 1, config.hidden_size, dtype=config.dtype, ), ) self.ffn_layer_scale = torch.nn.Parameter( torch.zeros( 1, 1, config.hidden_size, dtype=config.dtype, ), ) def forward( self, hidden_states: torch.Tensor, residual: torch.Tensor | None, positions: torch.Tensor = None, **kwargs: object, ): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) self_attention_output = torch.empty_like(hidden_states) if self.layer_type == "linear_attention": self.linear_attn( hidden_states=hidden_states, output=self_attention_output, ) elif self.layer_type == "full_attention": self.self_attn( hidden_states=hidden_states, output=self_attention_output, positions=positions, ) else: raise ValueError("Invalid layer_type") hidden_states = self_attention_output if self.layer_scale: if len(hidden_states.shape) == 2: hidden_states = hidden_states * ( self.attn_layer_scale.to(hidden_states.dtype)[0] + 1 ) else: hidden_states = hidden_states * ( self.attn_layer_scale.to(hidden_states.dtype) + 1 ) # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) if self.layer_scale: if len(hidden_states.shape) == 2: hidden_states = hidden_states * ( self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1 ) else: assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), ( f"shape must be the same {len(hidden_states.shape)}, " f"{len(self.ffn_layer_scale.shape)}" ) hidden_states = hidden_states * ( self.ffn_layer_scale.to(hidden_states.dtype) + 1 ) return hidden_states, residual @support_torch_compile class Qwen3NextModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: Qwen3NextConfig = vllm_config.model_config.hf_text_config parallel_config = vllm_config.parallel_config eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, ) def get_layer(prefix: str): return Qwen3NextDecoderLayer( vllm_config, layer_type=config.layer_types[extract_layer_index(prefix)], prefix=prefix, ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" ) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) if get_pp_group().is_last_rank: self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( self, input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, ) if not get_pp_group().is_last_rank: return IntermediateTensors( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=getattr(self.config, "num_experts", 0), num_redundant_experts=self.num_redundant_experts, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if name.startswith("mtp."): continue # Remapping the name of FP8 kv-scale. if name.endswith("scale"): name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue if "mlp.experts" in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue # name = apply_attn_prefix(name, params_dict) if name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. if ( name.endswith(".bias") or name.endswith("_bias") ) and name not in params_dict: continue if name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader( param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id, ) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue if name not in params_dict: logger.warning_once( f"Parameter {name} not found in params_dict, skip loading" ) continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class QwenNextMixtureOfExperts(MixtureOfExperts): def update_physical_experts_metadata( self, num_physical_experts: int, num_local_physical_experts: int, ) -> None: assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): moe = layer.mlp moe.n_local_physical_experts = num_local_physical_experts moe.n_physical_experts = num_physical_experts moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() def set_moe_parameters(self): self.expert_weights = [] self.moe_layers = [] example_moe = None for layer in self.model.layers: if isinstance(layer, Qwen3NextDecoderLayer) and isinstance( layer.mlp, Qwen3NextSparseMoeBlock ): example_moe = layer.mlp self.moe_layers.append(layer.mlp.experts) if example_moe is None: raise RuntimeError("No Qwen3Next layer found in the model.layers.") # Set MoE hyperparameters self.num_moe_layers = len(self.moe_layers) self.num_expert_groups = 1 self.num_shared_experts = 0 self.num_logical_experts = example_moe.n_logical_experts self.num_physical_experts = example_moe.n_physical_experts self.num_local_physical_experts = example_moe.n_local_physical_experts self.num_routed_experts = example_moe.n_routed_experts self.num_redundant_experts = example_moe.n_redundant_experts class Qwen3NextForCausalLM( nn.Module, HasInnerState, SupportsLoRA, SupportsPP, QwenNextMixtureOfExperts, IsHybrid, ): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_text_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config scheduler_config = vllm_config.scheduler_config if cache_config.mamba_cache_mode == "all": raise NotImplementedError( "Qwen3Next currently does not support 'all' prefix caching, " "please use '--mamba-cache-mode=align' instead" ) self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config self.model = Qwen3NextModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) # Set MoE hyperparameters self.set_moe_parameters() def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) def forward( self, input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, **kwargs: object, ): hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) return hidden_states @classmethod def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, vllm_config.cache_config.mamba_ssm_cache_dtype, ) @classmethod def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig" ) -> tuple[tuple[int, int], tuple[int, int]]: parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_text_config tp_size = parallel_config.tensor_parallel_size num_spec = ( vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0 ) return MambaStateShapeCalculator.gated_delta_net_state_shape( tp_size, hf_config.linear_num_key_heads, hf_config.linear_num_value_heads, hf_config.linear_key_head_dim, hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim, num_spec, ) @classmethod def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func() def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: return self.logits_processor(self.lm_head, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=["mtp."], ) return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() def gdn_attention_core( mixed_qkv: torch.Tensor, b: torch.Tensor, a: torch.Tensor, core_attn_out: torch.Tensor, layer_name: str, ) -> None: """ Custom op for the core attention computation. Only handles the convolution + recurrent attention part. Input/output projections are handled outside this op. """ forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] self._forward_core( mixed_qkv=mixed_qkv, b=b, a=a, core_attn_out=core_attn_out, ) def gdn_attention_core_fake( mixed_qkv: torch.Tensor, b: torch.Tensor, a: torch.Tensor, core_attn_out: torch.Tensor, layer_name: str, ) -> None: """Fake implementation for torch.compile.""" return direct_register_custom_op( op_name="gdn_attention_core", op_func=gdn_attention_core, mutates_args=["core_attn_out"], fake_impl=gdn_attention_core_fake, ) @triton.jit def fused_gdn_gating_kernel( g, beta_output, A_log, a, b, dt_bias, seq_len, NUM_HEADS: tl.constexpr, beta: tl.constexpr, threshold: tl.constexpr, BLK_HEADS: tl.constexpr, ): i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off mask = head_off < NUM_HEADS blk_A_log = tl.load(A_log + head_off, mask=mask) blk_a = tl.load(a + off, mask=mask) blk_b = tl.load(b + off, mask=mask) blk_bias = tl.load(dt_bias + head_off, mask=mask) # If the model is loaded in fp16, without the .float() here, A might be -inf x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) softplus_x = tl.where( beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x ) blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) # compute beta_output = sigmoid(b) blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) tl.store( beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask ) def fused_gdn_gating( A_log: torch.Tensor, a: torch.Tensor, b: torch.Tensor, dt_bias: torch.Tensor, beta: float = 1.0, threshold: float = 20.0, ) -> tuple[torch.Tensor, torch.Tensor]: """ Fused computation of g and beta for Gated Delta Net. g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) beta_output = b.sigmoid() TODO maybe use torch.compile to replace this triton kernel """ batch, num_heads = a.shape seq_len = 1 grid = (batch, seq_len, triton.cdiv(num_heads, 8)) g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device) fused_gdn_gating_kernel[grid]( g, beta_output, A_log, a, b, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1, ) return g, beta_output