From 396a83ea56e31ee1244da59a3cd00f33ace1627f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 03:13:38 +0000 Subject: [PATCH] Clean vLLM integration: use official paths, BF16 wo_a, proper weight mapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - deepseek_v4.py: Fresh upstream copy with minimal NVFP4 changes - wo_a uses quant_config=None (BF16 in NVFP4 checkpoint, no scales) - Added _make_deepseek_v4_nvfp4_weights_mapper() using official WeightsMapper API - Handles: self_attn→attn, mlp→ffn, gate_proj→w1, compressor renames, etc. - Mapper selected by quant_config.get_name() == 'modelopt_fp4' - deepseek_v4_attention.py: Fresh upstream copy with minimal NVFP4 changes - Removed _wo_a_act_quant and custom CuTeDSL wo_a runner - Added _apply_inv_rope_bf16() helper (inverse RoPE in BF16) - Detects BF16 wo_a (no weight_scale_inv) and uses BF16 path - FP8 einsum path kept as fallback for SM90 checkpoints - BF16 path: inverse RoPE → wo_a() → wo_b() (standard linear methods) --- vllm/patches/deepseek_v4.py | 523 ++++++++++---------------- vllm/patches/deepseek_v4_attention.py | 108 +++--- 2 files changed, 262 insertions(+), 369 deletions(-) diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 9d9ab6b6..c5f93db7 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -23,11 +23,14 @@ from vllm.model_executor.layers.deepseek_v4_attention import ( DeepseekV4MLAModules, DeepseekV4MultiHeadLatentAttentionWrapper, ) -from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( fused_topk_bias, ) +from vllm.model_executor.layers.fused_moe.router.norm_gate_linear import ( + NormGateLinear, +) from vllm.model_executor.layers.layernorm import RMSNorm 
from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -35,6 +38,12 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mhc import ( + HCHeadOp, + MHCFusedPostPreOp, + MHCPostOp, + MHCPreOp, +) from vllm.model_executor.layers.quantization import ( QuantizationConfig, QuantizationMethods, @@ -749,23 +758,23 @@ class DeepseekV4MoE(nn.Module): "deep_gemm_mega_moe for this checkpoint." ) - self.gate = GateLinear( - config.hidden_size, - config.n_routed_experts, - out_dtype=torch.float32, - bias=False, - prefix=f"{prefix}.gate", + # Fused RMSNorm + gate: owns both ffn_norm and the gate matmul. + self.norm_gate = NormGateLinear( + hidden_size=config.hidden_size, + num_experts=config.n_routed_experts, + rms_eps=config.rms_norm_eps, + prefix=f"{prefix}.norm_gate", ) - self.gate.e_score_correction_bias = None - self.gate.tid2eid = None + # Routing-side tensors live on ``norm_gate`` directly (not on the + # inner gate); they are initialized to None in NormGatedLinear and + # populated below depending on the MoE variant. 
is_hash_moe = extract_layer_index(prefix) < config.num_hash_layers self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32 - if is_hash_moe: # hash MoE doesn't use e_score_correction_bias # Use randint instead of empty to avoid garbage values causing # invalid memory access in dummy mode (--load-format="dummy") - self.gate.tid2eid = nn.Parameter( + self.norm_gate.tid2eid = nn.Parameter( torch.randint( 0, config.n_routed_experts, @@ -775,7 +784,7 @@ class DeepseekV4MoE(nn.Module): requires_grad=False, ) elif getattr(config, "topk_method", None) == "noaux_tc": - self.gate.e_score_correction_bias = nn.Parameter( + self.norm_gate.e_score_correction_bias = nn.Parameter( torch.empty(config.n_routed_experts, dtype=torch.float32), requires_grad=False, ) @@ -838,10 +847,9 @@ class DeepseekV4MoE(nn.Module): self.n_local_experts = config.n_routed_experts // self.tp_size self.experts_start_idx = self.tp_rank * self.n_local_experts self.experts_end_idx = self.experts_start_idx + self.n_local_experts - + # We don't pass `gate` into FusedMoE self.experts = FusedMoE( shared_experts=self.shared_experts, - gate=self.gate, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -851,8 +859,8 @@ class DeepseekV4MoE(nn.Module): prefix=f"{prefix}.experts", scoring_func=self.scoring_func, routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.gate.e_score_correction_bias, - hash_indices_table=self.gate.tid2eid, + e_score_correction_bias=self.norm_gate.e_score_correction_bias, + hash_indices_table=self.norm_gate.tid2eid, swiglu_limit=self.swiglu_limit, router_logits_dtype=torch.float32, ) @@ -860,40 +868,40 @@ class DeepseekV4MoE(nn.Module): def forward( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None ) -> torch.Tensor: - if self.gate.tid2eid is not None and input_ids is None: + if self.norm_gate.tid2eid is not None and input_ids is None: raise ValueError("DeepSeek V4 
hash MoE routing requires input_ids.") if not self.use_mega_moe: return self._forward_fused_moe(hidden_states, input_ids) org_shape = hidden_states.shape - router_logits, _ = self.gate(hidden_states) + normed_x, router_logits = self.norm_gate(hidden_states) topk_weights, topk_ids = fused_topk_bias( - hidden_states=hidden_states, + hidden_states=normed_x, gating_output=router_logits, scoring_func=self.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias.data - if self.gate.e_score_correction_bias is not None + e_score_correction_bias=self.norm_gate.e_score_correction_bias.data + if self.norm_gate.e_score_correction_bias is not None else None, topk=self.n_activated_experts, renormalize=self.renormalize, indices_type=self.hash_indices_dtype, input_tokens=input_ids, - hash_indices_table=self.gate.tid2eid, + hash_indices_table=self.norm_gate.tid2eid, routed_scaling_factor=self.routed_scaling_factor, ) activation_clamp = ( float(self.swiglu_limit) if self.swiglu_limit is not None else None ) final_hidden_states = self.experts( - hidden_states, + normed_x, topk_weights, topk_ids, activation_clamp=activation_clamp, ) if self.shared_experts is not None: - shared_output = self.shared_experts(hidden_states) + shared_output = self.shared_experts(normed_x) final_hidden_states += shared_output return final_hidden_states.view(org_shape) @@ -901,21 +909,14 @@ class DeepseekV4MoE(nn.Module): def _forward_fused_moe( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None ) -> torch.Tensor: + assert not self.experts.is_internal_router org_shape = hidden_states.shape - if self.experts.is_internal_router: - # In this case, the gate/router runs inside the FusedMoE class - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=hidden_states, - input_ids=input_ids, - ) - else: - router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - 
input_ids=input_ids, - ) + normed_x, router_logits = self.norm_gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=normed_x, + router_logits=router_logits, + input_ids=input_ids, + ) return final_hidden_states.view(org_shape) @@ -989,10 +990,8 @@ class DeepseekV4Attention(nn.Module): ) self.kv_norm = RMSNorm(self.head_dim, self.eps) - # wo_a is NOT quantized in the NVFP4 checkpoint (modelopt left it as bfloat16), - # but the attention forward pass expects FP8 (weight + weight_scale_inv). - # Pass quant_config=None to load bfloat16, then process_weights_after_loading - # will handle the FP8 quantization. + # wo_a is BF16 in the NVFP4 checkpoint (no quantization scales). + # Pass quant_config=None so it loads as a plain BF16 linear layer. self.wo_a = ColumnParallelLinear( self.n_heads * self.head_dim // self.n_groups, self.n_groups * self.o_lora_rank, @@ -1123,7 +1122,8 @@ class DeepseekV4DecoderLayer(nn.Module): self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn") self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) - self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) + # ``ffn_norm`` is owned by ``self.ffn.norm_gate`` (fused with the + # router gate matmul); see ``NormGatedLinear``. 
self.hc_mult = config.hc_mult self.hc_sinkhorn_iters = config.hc_sinkhorn_iters self.hc_eps = config.hc_eps @@ -1172,6 +1172,9 @@ class DeepseekV4DecoderLayer(nn.Module): ), requires_grad=False, ) + self.mhc_pre = MHCPreOp() + self.mhc_post = MHCPostOp() + self.mhc_fused_post_pre = MHCFusedPostPreOp() def hc_pre( self, @@ -1180,7 +1183,7 @@ class DeepseekV4DecoderLayer(nn.Module): hc_scale: torch.Tensor, hc_base: torch.Tensor, ): - post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre( + post_mix, res_mix, layer_input = self.mhc_pre( residual=x, fn=hc_fn, hc_scale=hc_scale, @@ -1200,17 +1203,17 @@ class DeepseekV4DecoderLayer(nn.Module): post: torch.Tensor, comb: torch.Tensor, ): - return torch.ops.vllm.mhc_post(x, residual, post, comb) + return self.mhc_post(x, residual, post, comb) - def forward( + def _forward_cuda( self, x: torch.Tensor, positions: torch.Tensor, input_ids: torch.Tensor | None, - post_mix: torch.Tensor | None, - res_mix: torch.Tensor | None, - residual: torch.Tensor | None, - ) -> torch.Tensor: + post_mix: torch.Tensor | None = None, + res_mix: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if residual is None: # Run standalone hc_pre on first layer residual = x @@ -1218,7 +1221,7 @@ class DeepseekV4DecoderLayer(nn.Module): x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base ) else: - residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre( + residual, post_mix, res_mix, x = self.mhc_fused_post_pre( x, residual, post_mix, @@ -1236,7 +1239,7 @@ class DeepseekV4DecoderLayer(nn.Module): x = self.attn_norm(x) x = self.attn(positions, x, None) - residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre( + residual, post_mix, res_mix, x = self.mhc_fused_post_pre( x, residual, post_mix, @@ -1250,11 +1253,58 @@ class DeepseekV4DecoderLayer(nn.Module): self.hc_post_alpha, self.hc_sinkhorn_iters, ) - - x = self.ffn_norm(x) + # ffn_norm is 
now folded into self.ffn.norm_gate; ffn() takes + # the pre-norm activation directly. x = self.ffn(x, input_ids) return x, residual, post_mix, res_mix + def _forward_rocm( + self, + x: torch.Tensor, + positions: torch.Tensor, + input_ids: torch.Tensor | None, + post_mix: torch.Tensor | None = None, + res_mix: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None + ]: + residual = x + x, post, comb = self.hc_pre( + x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base + ) + x = self.attn_norm(x) + x = self.attn(positions, x, None) + x = self.hc_post(x, residual, post, comb) + + residual = x + x, post, comb = self.hc_pre( + x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base + ) + # ffn_norm is now folded into self.ffn.norm_gate; ffn() takes + # the pre-norm activation directly. + x = self.ffn(x, input_ids) + x = self.hc_post(x, residual, post, comb) + return x, None, None, None + + def forward( + self, + x: torch.Tensor, + positions: torch.Tensor, + input_ids: torch.Tensor | None, + post_mix: torch.Tensor | None = None, + res_mix: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None + ]: + if current_platform.is_rocm(): + return self._forward_rocm( + x, positions, input_ids, post_mix, res_mix, residual + ) + + return self._forward_cuda(x, positions, input_ids, post_mix, res_mix, residual) + @support_torch_compile class DeepseekV4Model(nn.Module): @@ -1262,17 +1312,6 @@ class DeepseekV4Model(nn.Module): super().__init__() config = vllm_config.model_config.hf_config - - # Select weight mapper based on quantization method. - # NVFP4 (modelopt_fp4) checkpoints use different key naming - # than the default MXFP4 format. 
- quant_config = vllm_config.quant_config - if quant_config is not None and getattr(quant_config, "get_name", lambda: None)() == "modelopt_fp4": - self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper() - elif getattr(config, "expert_dtype", "fp4") != "fp4": - self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp8") - else: - self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") quant_config = vllm_config.quant_config self.config = config self.use_mega_moe = ( @@ -1355,7 +1394,7 @@ class DeepseekV4Model(nn.Module): torch.empty(1, dtype=torch.float32), requires_grad=False, ) - + self.hc_head_op = HCHeadOp() # Pre-hc_head residual stream buffer for the MTP draft. Stable # address (outside the cudagraph pool) so the copy_ in forward() # refreshes it correctly across captured shapes. @@ -1425,7 +1464,7 @@ class DeepseekV4Model(nn.Module): res_mix, residual, ) - else: + if layer is not None and current_platform.is_cuda(): hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix) if not get_pp_group().is_last_rank: @@ -1435,7 +1474,7 @@ class DeepseekV4Model(nn.Module): num_tokens = hidden_states.shape[0] self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1)) - hidden_states = hc_head( + hidden_states = self.hc_head_op( hidden_states, self.hc_head_fn, self.hc_head_scale, @@ -1455,9 +1494,6 @@ class DeepseekV4Model(nn.Module): ("attn.fused_wqa_wkv", "attn.wkv", 1), ("compressor.fused_wkv_wgate", "compressor.wkv", 0), ("compressor.fused_wkv_wgate", "compressor.wgate", 1), - # Indexer's compressor (same stacking pattern) - ("indexer.compressor.fused_wkv_wgate", "indexer.compressor.wkv", 0), - ("indexer.compressor.fused_wkv_wgate", "indexer.compressor.wgate", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -1473,14 +1509,6 @@ class DeepseekV4Model(nn.Module): # Pre-compute expert mapping ONCE. 
expert_mapping = self.get_expert_mapping() - # NVFP4 compressor/indexer scale params need special handling: - # wkv.input_scale (shape [1]) + wgate.input_scale (shape [1]) - # must be concatenated into fused_wkv_wgate.input_scale (shape [2]). - # The default stacking path fails because PerTensorScaleParameter's - # weight_loader asserts shape equality. - # We buffer them and load once both shards are available. - compressor_scale_buffer: dict[str, dict[int, torch.Tensor]] = {} - for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -1492,44 +1520,8 @@ class DeepseekV4Model(nn.Module): if is_pp_missing_parameter(name, self): break - if name not in params_dict: - # The stacked param doesn't exist — skip - # (e.g. indexer.compressor.fused_wkv_wgate on layers - # that don't have the full indexer structure) - break param = params_dict[name] weight_loader = param.weight_loader - - # NVFP4 scale params for stacked fused_wkv_wgate need - # special handling: each shard (wkv, wgate) has scale - # shape [1] or [head_dim, K], but the fused param has - # shape [2] or [2*head_dim, K]. The default stacking - # weight_loader can't handle this for PerTensorScale or - # ModelWeight scale params. Buffer and concatenate. 
- is_compressor_scale = ( - "fused_wkv_wgate" in name - and name.endswith(( - "input_scale", - "weight_scale", - "weight_scale_2", - )) - ) - if is_compressor_scale: - # Verify the fused param exists before buffering - if name not in params_dict: - print( - f"COMPRESSOR_SCALE_SKIP: {name} not in params_dict", - flush=True, - ) - break - if is_compressor_scale: - # Buffer the shard for later concatenation - if name not in compressor_scale_buffer: - compressor_scale_buffer[name] = {} - compressor_scale_buffer[name][shard_id] = loaded_weight - loaded_params.add(name) - break - weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) break @@ -1582,51 +1574,13 @@ class DeepseekV4Model(nn.Module): else: if is_pp_missing_parameter(name, self): continue - if name not in params_dict: - print(f"Skipping weight {name} (not in model params)", - flush=True) - continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader ) - try: - weight_loader(param, loaded_weight) - except (AssertionError, RuntimeError) as e: - print( - f"WEIGHT_LOAD_FAIL: name={name} " - f"param_shape={param.data.shape if hasattr(param, 'data') else '?'} " - f"loaded_shape={loaded_weight.shape} " - f"loaded_dtype={loaded_weight.dtype} " - f"error={e}", - flush=True, - ) - raise - - # Load buffered compressor/indexer scale params. - # These are NVFP4 quantization scales that need concatenation - # across shards (wkv=shard0, wgate=shard1) before loading. - for name, shards in compressor_scale_buffer.items(): - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - if len(shards) == 2: - # Concatenate shard 0 and shard 1 along dim 0. - # Scales may be 0-dim scalars (input_scale, weight_scale_2) - # or N-dim tensors (weight_scale); reshape scalars to 1-d. 
- s0, s1 = shards[0], shards[1] - if s0.ndim == 0: - s0 = s0.reshape(1) - if s1.ndim == 0: - s1 = s1.reshape(1) - stacked = torch.cat([s0, s1], dim=0) - else: - stacked = shards[0] - assert param.data.shape == stacked.shape, ( - f"Scale shape mismatch for {name}: " - f"param={param.data.shape} loaded={stacked.shape}" - ) - param.data.copy_(stacked) + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue return loaded_params @@ -1647,89 +1601,6 @@ class DeepseekV4Model(nn.Module): def finalize_mega_moe_weights(self) -> None: for layer in islice(self.layers, self.start_layer, self.end_layer): layer.ffn.finalize_mega_moe_weights() - # Initialize wo_a NVFP4 runner instead of quantizing to FP8 - attn = layer.attn - if hasattr(attn, 'wo_a') and attn.wo_a.weight.dtype == torch.bfloat16: - self._init_wo_a_nvfp4(attn) - - @staticmethod - def _init_wo_a_nvfp4(attn) -> None: - """Initialize CuTeDSL NVFP4 runner for wo_a. - - Replaces the old _quantize_wo_a_to_fp8 approach. Instead of - quantizing to FP8 and using DeepGEMM fp8_einsum (which crashes - on Blackwell), we quantize to NVFP4 and use our CuTeDSL kernel. - - wo_a is a grouped matmul (bmm) with n_local_groups groups. 
- Each group: (tokens, heads_per_group * head_dim) × (heads_per_group * head_dim, o_lora_rank) - """ - from cutedsl.wo_a_grouped_linear import CuTeDSLNvfp4WoA - - wo_a = attn.wo_a - weight_bf16 = wo_a.weight.data # (out_features, in_features) = (n_groups * o_lora_rank, heads_per_group * head_dim) - - n_local_groups = attn.n_local_groups - heads_per_group = attn.n_local_heads // n_local_groups - head_dim = attn.head_dim - o_lora_rank = attn.o_lora_rank - - runner = CuTeDSLNvfp4WoA( - n_local_groups=n_local_groups, - heads_per_group=heads_per_group, - head_dim=head_dim, - o_lora_rank=o_lora_rank, - max_num_tokens=8192, - device=weight_bf16.device, - ) - - # The weight is (n_groups * o_lora_rank, heads_per_group * head_dim) - # set_bf16_weight handles the 2D (dense) format - runner.set_bf16_weight(weight_bf16) - runner.finalize_weights() - - # Warmup: compute activation global scale from sample data - # This uses a representative random sample; the scale will be - # recomputed on the first real forward pass with actual data. 
- with torch.no_grad(): - sample = torch.randn( - 8, n_local_groups * heads_per_group, head_dim, - dtype=torch.bfloat16, device=weight_bf16.device, - ) * 2.0 - runner._ensure_initialized() - runner.compute_activation_global_scale(sample) - - # Store the runner on the attention module - attn._wo_a_nvfp4 = runner - - -@torch.compile(backend=current_platform.simple_compile_backend) -def hc_head( - hidden_states: torch.Tensor, - hc_fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - rms_norm_eps: float, - hc_eps: float, -) -> torch.Tensor: - hc_mult, hidden_size = hidden_states.shape[-2:] - outer_shape = hidden_states.shape[:-2] - hs_flat = hidden_states.view(-1, hc_mult, hidden_size) - num_tokens = hs_flat.shape[0] - out = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device - ) - torch.ops.vllm.hc_head_fused_kernel( - hs_flat, - hc_fn, - hc_scale, - hc_base, - out, - hidden_size, - rms_norm_eps, - hc_eps, - hc_mult, - ) - return out.view(*outer_shape, hidden_size) def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: @@ -1761,7 +1632,13 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: orig_to_new_suffix={ "head.weight": "lm_head.weight", "embed.weight": "embed_tokens.weight", - ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", + # Pre-MoE norm + gate are now owned by ``DeepseekV4MoE.norm_gate`` + # (see NormGatedLinear). + ".ffn_norm.weight": ".ffn.norm_gate.norm.weight", + ".ffn.gate.weight": ".ffn.norm_gate.gate.weight", + ".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias", + # Hash MoE table also moved off the inner gate. 
+ ".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid", }, orig_to_new_substr={ ".attn.compressor.": ".attn.mla_attn.compressor.", @@ -1770,18 +1647,16 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: ) - - def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper: """Weight mapper for NVFP4 (ModelOpt) DeepSeek-V4 checkpoints. NVFP4 checkpoints use different key naming than the upstream MXFP4 format: + - ``self_attn`` prefix instead of ``attn`` + - ``mlp`` prefix instead of ``ffn`` - Expert weights: gate_proj/up_proj/down_proj (not w1/w3/w2) - Scales already have .weight_scale / .weight_scale_2 / .input_scale suffixes - - Shared expert uses down_proj (not w2) - - Self-attention uses .self_attn. prefix (same as checkpoint, renamed to .attn.) - - This is the mapper that should be used when quantization is modelopt_fp4. + - Compressor uses kv_proj/gate_proj (not wkv/wgate) + - o_a_proj is BF16 (no quantization scales) """ expert_rename_regex = { re.compile(r"(\.experts\.\d+\.)gate_proj\."): r"\1w1.", @@ -1789,82 +1664,83 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper: re.compile(r"(\.experts\.\d+\.)down_proj\."): r"\1w2.", } - suffix_renames = { - # The NVFP4 checkpoint already uses lm_head / embed_tokens directly, - # no suffix renames needed (unlike the MXFP4 upstream format). - } - - # NOTE: specific renames MUST come before general ones (applied in order) - substr_renames = { - # === Indexer params (MUST come before .self_attn.compressor. - # so that indexer keys are captured before the compressor prefix - # rewrite moves them under mla_attn.compressor) === - # The checkpoint puts indexer under self_attn.compressor.indexer.* - # but the model has indexer at attn.indexer.* (sibling of compressor, - # NOT nested under it). 
- ".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.", - ".self_attn.compressor.indexer.weights_proj.": ".attn.indexer.weights_proj.", - ".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.", - ".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.", - ".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.", - ".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape", - # === Compressor (non-indexer) NVFP4 renames === - # Checkpoint uses kv_proj/gate_proj, model uses wkv/wgate - # (for stacking into fused_wkv_wgate). - "compressor.kv_proj.": "compressor.wkv.", - "compressor.gate_proj.": "compressor.wgate.", - "compressor.kv_norm.": "compressor.norm.", - "compressor.position_bias": "compressor.ape", - # === Attention compressor (MUST come after indexer renames - # so that remaining .self_attn.compressor. (non-indexer) keys - # become .attn.mla_attn.compressor.) === - ".self_attn.compressor.": ".attn.mla_attn.compressor.", - # === Attention projections (specific before .self_attn. → .attn.) === - ".self_attn.q_a_proj.": ".attn.wq_a.", - ".self_attn.kv_proj.": ".attn.wkv.", - ".self_attn.q_b_proj.": ".attn.wq_b.", - ".self_attn.o_a_proj.": ".attn.wo_a.", - ".self_attn.o_b_proj.": ".attn.wo_b.", - ".self_attn.q_a_norm.": ".attn.q_norm.", - ".self_attn.kv_norm.": ".attn.kv_norm.", - ".self_attn.sinks": ".attn.attn_sink", - # Shared expert projections (specific before .mlp. → .ffn.) 
- ".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.", - ".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.", - ".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.", - # General renames - ".mlp.": ".ffn.", - ".self_attn.": ".attn.", - # Layer norms (checkpoint uses input_layernorm / post_attention_layernorm, - # model uses attn_norm / ffn_norm) - "input_layernorm.": "attn_norm.", - "post_attention_layernorm.": "ffn_norm.", - # Per-layer HC params (checkpoint uses attn_hc / ffn_hc with dot, - # model uses hc_attn / hc_ffn with underscore) - ".attn_hc.fn": ".hc_attn_fn", - ".attn_hc.base": ".hc_attn_base", - ".attn_hc.scale": ".hc_attn_scale", - ".ffn_hc.fn": ".hc_ffn_fn", - ".ffn_hc.base": ".hc_ffn_base", - ".ffn_hc.scale": ".hc_ffn_scale", - # Top-level hc_head params (checkpoint uses hc_head.fn etc, - # model uses hc_head_fn etc) - "hc_head.fn": "hc_head_fn", - "hc_head.base": "hc_head_base", - "hc_head.scale": "hc_head_scale", - } - return WeightsMapper( orig_to_new_prefix={ "layers.": "model.layers.", - "embed_tokens.": "model.embed_tokens.", + "embed.": "model.embed.", "norm.": "model.norm.", "hc_head": "model.hc_head", "mtp.": "model.mtp.", }, orig_to_new_regex=expert_rename_regex, - orig_to_new_suffix=suffix_renames, - orig_to_new_substr=substr_renames, + # No suffix renames needed — NVFP4 checkpoint uses + # .weight_scale / .weight_scale_2 / .input_scale directly. + orig_to_new_suffix={ + "head.weight": "lm_head.weight", + "embed.weight": "embed_tokens.weight", + # Pre-MoE norm + gate are now owned by DeepseekV4MoE.norm_gate + ".ffn_norm.weight": ".ffn.norm_gate.norm.weight", + ".ffn.gate.weight": ".ffn.norm_gate.gate.weight", + ".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias", + ".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid", + }, + # Specific renames MUST come before general ones (applied in order). + orig_to_new_substr={ + # Indexer params (MUST come before .self_attn.compressor. 
+ # so indexer keys are captured before the compressor prefix + # rewrite moves them under mla_attn.compressor). + ".self_attn.compressor.indexer.q_b_proj.": + ".attn.indexer.wq_b.", + ".self_attn.compressor.indexer.weights_proj.": + ".attn.indexer.weights_proj.", + ".self_attn.compressor.indexer.kv_norm.": + ".attn.indexer.k_norm.", + ".self_attn.compressor.indexer.kv_proj.": + ".attn.indexer.compressor.wkv.", + ".self_attn.compressor.indexer.gate_proj.": + ".attn.indexer.compressor.wgate.", + ".self_attn.compressor.indexer.position_bias": + ".attn.indexer.compressor.ape", + # Compressor (non-indexer) renames + "compressor.kv_proj.": "compressor.wkv.", + "compressor.gate_proj.": "compressor.wgate.", + "compressor.kv_norm.": "compressor.norm.", + "compressor.position_bias": "compressor.ape", + # Attention compressor (after indexer renames) + ".self_attn.compressor.": ".attn.mla_attn.compressor.", + # Attention projections (specific before .self_attn. → .attn.) + ".self_attn.q_a_proj.": ".attn.wq_a.", + ".self_attn.kv_proj.": ".attn.wkv.", + ".self_attn.q_b_proj.": ".attn.wq_b.", + ".self_attn.o_a_proj.": ".attn.wo_a.", + ".self_attn.o_b_proj.": ".attn.wo_b.", + ".self_attn.q_a_norm.": ".attn.q_norm.", + ".self_attn.kv_norm.": ".attn.kv_norm.", + ".self_attn.sinks": ".attn.attn_sink", + # Shared expert projections (specific before .mlp. → .ffn.) 
+ ".mlp.shared_experts.gate_proj.": + ".ffn.shared_experts.w1.", + ".mlp.shared_experts.up_proj.": + ".ffn.shared_experts.w3.", + ".mlp.shared_experts.down_proj.": + ".ffn.shared_experts.down_proj.", + # General renames + ".mlp.": ".ffn.", + ".self_attn.": ".attn.", + # Layer norms + "input_layernorm.": "attn_norm.", + "post_attention_layernorm.": "ffn_norm.", + # HC params + ".attn_hc.fn": ".hc_attn_fn", + ".attn_hc.base": ".hc_attn_base", + ".attn_hc.scale": ".hc_attn_scale", + ".ffn_hc.fn": ".hc_ffn_fn", + ".ffn_hc.base": ".hc_ffn_base", + ".ffn_hc.scale": ".hc_ffn_scale", + "hc_head.fn": "hc_head_fn", + "hc_head.base": "hc_head_base", + "hc_head.scale": "hc_head_scale", + }, ) @@ -1879,21 +1755,18 @@ class DeepseekV4ForCausalLM(nn.Module, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config - + self.config = config + expert_dtype = getattr(config, "expert_dtype", "fp4") + quant_config = vllm_config.quant_config # Select weight mapper based on quantization method. # NVFP4 (modelopt_fp4) checkpoints use different key naming # than the default MXFP4 format. 
- quant_config = vllm_config.quant_config - if quant_config is not None and getattr(quant_config, "get_name", lambda: None)() == "modelopt_fp4": + if (quant_config is not None + and quant_config.get_name() == "modelopt_fp4"): self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper() - elif getattr(config, "expert_dtype", "fp4") != "fp4": - self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp8") - else: - self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") - self.config = config - expert_dtype = getattr(config, "expert_dtype", "fp4") - if expert_dtype != "fp4": - self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype) + elif expert_dtype != "fp4": + self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper( + expert_dtype) self.model = self.model_cls( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py index 5bc924eb..7ed9d59f 100644 --- a/vllm/patches/deepseek_v4_attention.py +++ b/vllm/patches/deepseek_v4_attention.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from transformers import DeepseekV2Config, DeepseekV3Config import vllm.envs as envs +from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.model_executor.layers.linear import ( ReplicatedLinear, ) @@ -46,14 +47,9 @@ from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor -from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.input_quant_fp8 import ( - QuantFP8, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, 
-) + from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( execute_in_parallel, @@ -186,19 +182,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): self.kv_norm = mla_modules.kv_norm self.wo_a = mla_modules.wo_a - # NVFP4 runner for wo_a — replaces DeepGEMM fp8_einsum. - # Initialized in DeepseekV4Model.finalize_mega_moe_weights() - # after wo_a BF16 weights are loaded. - self._wo_a_nvfp4 = None - - self._wo_a_act_quant = QuantFP8( - static=False, - group_shape=GroupShape(1, 128), - use_ue8m0=True, - ) - # Bypass packed-for-deepgemm path — we need FP32 scales (not packed - # INT32) so fp8_einsum can handle layout transform internally. - self._wo_a_act_quant.use_deep_gemm_supported = False self.wo_b = mla_modules.wo_b # Pick fp8_einsum recipe based on GPU arch: @@ -321,25 +304,29 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) return self.wo_b(z.flatten(1)) - # O projection: inverse RoPE + NVFP4 grouped GEMM + wo_b - # Using our CuTeDSL NVFP4 kernel instead of DeepGEMM fp8_einsum - if self._wo_a_nvfp4 is not None: - from cutedsl.inverse_rope import inverse_rope_bf16 - o_inv = inverse_rope_bf16( + # Detect if wo_a has FP8 weights (weight_scale_inv attribute). + # NVFP4 checkpoints leave wo_a as BF16 (no quantization scales), + # so we use inverse RoPE in BF16 + regular matmul instead of + # the FP8 einsum path (which crashes on Blackwell SM100). 
+ has_fp8_weights = hasattr(self.wo_a, 'weight_scale_inv') + + if not has_fp8_weights: + # BF16 wo_a path: inverse RoPE in BF16, then regular matmul + o_inv = _apply_inv_rope_bf16( o, positions, self.rotary_emb.cos_sin_cache.to(torch.float32), nope_dim=self.nope_head_dim, rope_dim=self.rope_head_dim, ) - # Activation global scale is computed during init (finalize_mega_moe_weights) - z = self._wo_a_nvfp4(o_inv) + z, _ = self.wo_a(o_inv.reshape(num_tokens, -1)) + z = z.view(num_tokens, self.n_local_groups, self.o_lora_rank) return self.wo_b(z.flatten(1)) - # Fallback: original DeepGEMM path (for non-Blackwell or before init) + # FP8 wo_a path: fused inverse RoPE + FP8 quant + einsum (SM90 only) o_fp8, o_scale = fused_inv_rope_fp8_quant( o, positions, - self.rotary_emb.cos_sin_cache.to(torch.float32), + self.rotary_emb.cos_sin_cache, n_groups=self.n_local_groups, heads_per_group=self.n_local_heads // self.n_local_groups, nope_dim=self.nope_head_dim, @@ -384,14 +371,11 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): compressor = self.compressor def compressor_kv_score() -> torch.Tensor: - # Use forward() for quantized layers (NVFP4, FP8, etc.) - # — raw torch.mm doesn't work with packed/dequantized weights. - # MergedColumnParallelLinear with return_bias=False returns - # a tensor directly. 
- result = compressor.fused_wkv_wgate(hidden_states) - if isinstance(result, tuple): - result = result[0] - return result.to(torch.float32) + return torch.mm( + hidden_states, + compressor.fused_wkv_wgate.weight.T, + out_dtype=torch.float32, + ) aux_fns[0] = compressor_kv_score @@ -404,10 +388,11 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): return weights def indexer_compressor_kv_score() -> torch.Tensor: - result = indexer.compressor.fused_wkv_wgate(hidden_states) - if isinstance(result, tuple): - result = result[0] - return result.to(torch.float32) + return torch.mm( + hidden_states, + indexer.compressor.fused_wkv_wgate.weight.T, + out_dtype=torch.float32, + ) aux_fns[1] = indexer_weights_proj aux_fns[2] = indexer_compressor_kv_score @@ -567,12 +552,48 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): swa_kv_cache_2d, swa_metadata.slot_mapping, positions.to(torch.int64), - self.rotary_emb.cos_sin_cache.to(torch.float32), + self.rotary_emb.cos_sin_cache, self.eps, swa_metadata.block_size, ) +def _apply_inv_rope_bf16( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + nope_dim: int, + rope_dim: int, +) -> torch.Tensor: + """Apply inverse RoPE to attention output in BF16. + + Inverse RoPE is just RoPE with cos → cos, sin → -sin. + Uses GPT-J style (interleaved) rotary embedding. 
+ """ + if rope_dim == 0 or o.numel() == 0: + return o + half_rot = rope_dim // 2 + o_f32 = o.to(torch.float32) + cache = cos_sin_cache.index_select(0, positions.to(torch.long)) + cos = cache[:, :half_rot].to(torch.float32) + sin = cache[:, half_rot : 2 * half_rot].to(torch.float32) + view_shape = (positions.shape[0], 1, half_rot) + cos = cos.view(view_shape) + sin = sin.view(view_shape) + rope = o_f32[..., nope_dim:] + y_even = rope[..., 0::2] + y_odd = rope[..., 1::2] + # Inverse rotation (sin → -sin): rotate each (even, odd) pair by -theta, undoing the forward RoPE + rope_out = torch.stack( + (y_even * cos + y_odd * sin, y_odd * cos - y_even * sin), + dim=-1, + ).flatten(-2) + o_f32 = o_f32.clone() + o_f32[..., nope_dim:] = rope_out + return o_f32.to(o.dtype) + + +@eager_break_during_capture def deepseek_v4_attention( hidden_states: torch.Tensor, positions: torch.Tensor, @@ -1126,10 +1147,9 @@ class DeepseekV4Indexer(nn.Module): hidden_size, self.n_head, bias=False, - quant_config=quant_config, + quant_config=None, prefix=f"{prefix}.weights_proj", ) - self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.softmax_scale = self.head_dim**-0.5 self.scale_fmt = "ue8m0"