From 450793311cd15bfb989c6bcd29e4f2f62b76ef95 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 18 May 2026 20:27:42 +0000 Subject: [PATCH] Wire CuTeDSL kernels into vLLM: replace all BF16 dequant with native NVFP4 - CuTeDSLNvfp4Method: custom quant method that creates CuTeDSL runners during process_weights_after_loading, then swaps to CuTeDSLNvfp4LinearMethod for forward dispatch - Attention projections (fused_wqa_wkv, wq_b, wo_b) now route through CuTeDSLNvfp4Linear (cosine 0.992-0.996 vs BF16 reference) - Shared expert now uses CuTeDSLSharedExpertRunner (cosine 0.992 vs BF16) with monkey-patched forward for fused L1+SiLU+L2 pipeline - Deleted all BF16 dequant code (_dequant_nvfp4_to_bf16, _post_quant_fix, input_scale fixes) - Deleted _post_quant_fix hook from utils.py - Fixed SwiGLU clamp: gate clamped BEFORE SiLU (matching SiluAndMulWithClamp) - Cleaned up all debug prints - Updated Dockerfile with new kernel files --- Dockerfile | 3 + cutedsl/nvfp4_linear.py | 5 +- cutedsl/shared_expert_pipeline.py | 16 +- vllm/cutedsl_quant_method.py | 135 ++++++++++++ vllm/nvfp4_cutedsl.py | 4 - vllm/patches/deepseek_v4.py | 338 +++++++++++++++--------------- vllm/patches/utils.py | 6 - 7 files changed, 314 insertions(+), 193 deletions(-) create mode 100644 vllm/cutedsl_quant_method.py diff --git a/Dockerfile b/Dockerfile index 7bc1a966..3622f0ae 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,6 +35,9 @@ ARG VLLM_LOADER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/ COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py COPY vllm/nvfp4_cutedsl.py ${VLLM_MODELS_DIR}/nvfp4_cutedsl.py +COPY vllm/cutedsl_quant_method.py ${VLLM_MODELS_DIR}/cutedsl_quant_method.py +COPY cutedsl/nvfp4_linear.py /root/nvfp4-megamoe-kernel/cutedsl/nvfp4_linear.py +COPY cutedsl/shared_expert_pipeline.py /root/nvfp4-megamoe-kernel/cutedsl/shared_expert_pipeline.py COPY vllm/patches/utils.py ${VLLM_LOADER_DIR}/utils.py RUN sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \ diff --git a/cutedsl/nvfp4_linear.py b/cutedsl/nvfp4_linear.py index b9d8c7cc..9370b5bc 100644 --- a/cutedsl/nvfp4_linear.py +++ b/cutedsl/nvfp4_linear.py @@ -62,9 +62,8 @@ class CuTeDSLNvfp4Linear: self._gsa_buf = None self._buffers_allocated = False - import os - print(f"[CLAWMINE] Nvfp4Linear init: in={in_features} out={out_features} " - f"max_tokens={max_num_tokens} pid={os.getpid()}", flush=True) + print(f" Nvfp4Linear init: in={in_features} out={out_features} " + f"max_tokens={max_num_tokens}", flush=True) def finalize_weights(self): """Process weights for CuTeDSL GEMM.""" diff --git a/cutedsl/shared_expert_pipeline.py b/cutedsl/shared_expert_pipeline.py index f2e66fd6..7650fde5 100644 --- a/cutedsl/shared_expert_pipeline.py +++ b/cutedsl/shared_expert_pipeline.py @@ -83,9 +83,8 @@ class CuTeDSLSharedExpertRunner: self._expert_offsets_buf = None self._buffers_allocated = False - import os - print(f"[CLAWMINE] SharedExpert init: hidden={hidden_size} intermediate={intermediate_size} " - f"max_tokens={max_num_tokens} pid={os.getpid()}", flush=True) + print(f" SharedExpert init: hidden={hidden_size} intermediate={intermediate_size} " + f"max_tokens={max_num_tokens}", flush=True) def set_swiglu_limit(self, limit: float): self.swiglu_limit = limit @@ -192,11 +191,10 @@ class CuTeDSLSharedExpertRunner: if l1_out is not None and not torch.isnan(l1_out).any(): gate = l1_out[:, :self.intermediate_size] up = l1_out[:, self.intermediate_size:] - gate_silu = torch.nn.functional.silu(gate) if self.swiglu_limit is not None: - gate_silu = gate_silu.clamp(max=self.swiglu_limit) + gate = gate.clamp(max=self.swiglu_limit) up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) - activated = gate_silu * up + activated = torch.nn.functional.silu(gate) * up _, _, l2_gs = quantize_to_nvfp4(activated) self._l2_activation_global_scale = l2_gs @@ -288,11 +286,11 @@ class CuTeDSLSharedExpertRunner: gate = l1_out[:, :self.intermediate_size] up = l1_out[:, self.intermediate_size:] - gate_silu = torch.nn.functional.silu(gate) if self.swiglu_limit is not None: - gate_silu = gate_silu.clamp(max=self.swiglu_limit) + # Match SiluAndMulWithClamp: clamp gate BEFORE silu, clamp up to [-limit, limit] + gate = gate.clamp(max=self.swiglu_limit) up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) - intermediate = gate_silu * up + intermediate = torch.nn.functional.silu(gate) * up output = self._run_l2(intermediate) return output diff --git a/vllm/cutedsl_quant_method.py b/vllm/cutedsl_quant_method.py new file mode 100644 index 00000000..ce6aaa1d --- /dev/null +++ b/vllm/cutedsl_quant_method.py @@ -0,0 +1,135 @@ +"""CuTeDSL NVFP4 Quantization Method for vLLM + +Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM. +After process_weights_after_loading, the module's quant_method is swapped +to CuTeDSLNvfp4LinearMethod which routes forward() through CuTeDSL. +""" + +import torch + +from vllm.model_executor.layers.linear import LinearMethodBase + + +class CuTeDSLNvfp4Method(LinearMethodBase): + """Pre-processing quant method that sets up CuTeDSL runners. + + Installed on NVFP4 linear layers before process_weights_after_loading. + When vLLM calls process_weights_after_loading, this method: + 1. Reads NVFP4 weights (uint8, float8 block scales, float32 global scales) + 2. Converts to CuTeDSL format + 3. Creates CuTeDSLNvfp4Linear runner + 4. Stores runner on the module + 5. Frees original weight/scale params + 6. Replaces quant_method with CuTeDSLNvfp4LinearMethod + """ + + def __init__(self, is_fused: bool = False): + """ + Args: + is_fused: True for MergedColumnParallelLinear with two sub-projections + (e.g., fused_wqa_wkv with q_a + kv, or gate_up_proj with gate + up). + Handles dual weight_scale_2 the same way as MoE L1: + normalize to max(gs1, gs2), fold ratio into block scales. + """ + self.is_fused = is_fused + + def create_weights(self, layer, input_size_per_partition, + output_partition_sizes, input_size, output_size, + params_dtype, **extra_weight_attrs): + # We don't create weights — ModelOptNvFp4LinearMethod already did that. + # This method is only installed after weight loading. + pass + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear + + w_uint8 = layer.weight.data # (out, in//2) uint8 packed E2M1 + device = w_uint8.device + out_features = w_uint8.shape[0] + in_features = w_uint8.shape[1] * 2 # 2 FP4 values per uint8 + + # Convert uint8 → float4_e2m1fn_x2, then permute to (K_packed, N) + w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + + # Block scales: (N, K_sf) → (K_sf, N) + sf = layer.weight_scale.data + if sf.dtype != torch.float8_e4m3fn: + sf = sf.to(torch.float8_e4m3fn) + sf = sf.permute(1, 0).contiguous() + + # Global scale + weight_scale_2 = layer.weight_scale_2.data + if self.is_fused and weight_scale_2.numel() == 2: + # Dual global scales (fused_wqa_wkv: q_a + kv, gate_up: gate + up) + gs1 = weight_scale_2[0].item() + gs2 = weight_scale_2[1].item() + gs = max(gs1, gs2) + + # Fold ratio into block scales via float32 round-trip + if gs1 != gs2: + sf_f32 = sf.float() + # After permute to (K_sf, N): first sub-projection's output + # columns, then second sub-projection's output columns + logical_widths = getattr(layer, 'logical_widths', None) + if logical_widths is not None and len(logical_widths) == 2: + split_point = logical_widths[0] + else: + split_point = out_features // 2 + sf_f32[:, :split_point] *= (gs1 / gs) + sf_f32[:, split_point:] *= (gs2 / gs) + sf = sf_f32.to(torch.float8_e4m3fn) + else: + gs = weight_scale_2.max().item() + + # Create CuTeDSL runner + runner = CuTeDSLNvfp4Linear( + in_features=in_features, + out_features=out_features, + device=device, + ) + runner.fp4 = [w_fp4] + runner.sf = [sf] + runner.gs = [gs] + runner.finalize_weights() + + # Store runner on the module + layer._cutedsl_runner = runner + + # Warmup: compute activation global scale from sample data + with torch.no_grad(): + sample = torch.randn(min(8, 256), in_features, + dtype=torch.bfloat16, device=device) * 2.0 + runner.compute_activation_global_scale(sample) + + # Replace weight with dummy BF16 (needed by vLLM module introspection) + layer.weight = torch.nn.Parameter( + torch.zeros(out_features, in_features, dtype=torch.bfloat16, + device=device), + requires_grad=False, + ) + + # Free original NVFP4 params + for attr in ("weight_scale", "weight_scale_2", "input_scale", + "input_global_scale", "input_global_scale_inv", + "weight_global_scale", "alpha", "weight_scale_inv"): + if hasattr(layer, attr): + try: + delattr(layer, attr) + except Exception: + pass + + # Swap quant method to the forward-only one + layer.quant_method = CuTeDSLNvfp4LinearMethod() + + def apply(self, layer, x, bias=None): + raise NotImplementedError( + "CuTeDSLNvfp4Method should be replaced by " + "CuTeDSLNvfp4LinearMethod after process_weights_after_loading" + ) + + +class CuTeDSLNvfp4LinearMethod(LinearMethodBase): + """Forward path: BF16 input → CuTeDSL NVFP4 GEMM → BF16 output.""" + + def apply(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor: + return layer._cutedsl_runner(x) diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 1009ab07..87a863e3 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -87,10 +87,6 @@ class CuTeDSLMoERunner: self._max_chunks_per_expert = cutedsl_ceil_div( self.max_num_tokens * self.top_k, self.num_experts * 128 ) - import os - print(f"[CLAWMINE] Runner init: max_num_tokens={self.max_num_tokens} top_k={self.top_k} " - f"num_experts={self.num_experts} max_chunks={self._max_chunks_per_expert} " - f"pid={os.getpid()}", flush=True) self._buffers_allocated = False def set_swiglu_limit(self, limit: float | None): diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 82e2a047..e6ad2b4e 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -1683,104 +1683,55 @@ class DeepseekV4Model(nn.Module): layer.ffn.finalize_mega_moe_weights() def _convert_nvfp4_post_load(self): - """Post-load conversion of NVFP4 weights for vLLM compatibility. - - Fixes the attention input_scale values BEFORE - process_weights_after_loading runs. The checkpoint input_scale - values are wrong and cause NaN during activation quantization. - We compute correct values by dequantizing to BF16 temporarily - and running a warmup forward. - - wo_a is converted to FP8 for fp8_einsum (no input_scale needed). - Compressor weights are reconstructed from checkpoint sub-weights. + """Post-load setup of CuTeDSL NVFP4 runners for attention and shared experts. + + Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM. + For attention projections (fused_wqa_wkv, wq_b, wo_b), installs + CuTeDSLNvfp4Method which creates CuTeDSL runners during + process_weights_after_loading. + + For shared experts, creates CuTeDSLSharedExpertRunner which handles + the full L1 (gate_up) + SiLU + L2 (down) pipeline. + + wo_a is converted to FP8 for fp8_einsum (unchanged). + Compressor weights are reconstructed from checkpoint sub-weights (unchanged). """ - fp8_proj_names = {"wo_a"} + from vllm.model_executor.models.cutedsl_quant_method import CuTeDSLNvfp4Method + fp8_converted = 0 compressor_converted = 0 - input_scale_fixes = 0 + cutedsl_installed = 0 + shared_expert_installed = 0 _shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None from tqdm import tqdm - for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (fix)NVFP4 attn input_scale", unit="layer"): + for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" NVFP4→CuTeDSL setup", unit="layer"): attn = layer.attn - + # FP8 conversion: wo_a (used by fp8_einsum, no input_scale) FP8_MAX = torch.finfo(torch.float8_e4m3fn).max - for proj_name in fp8_proj_names: - if not hasattr(attn, proj_name): - continue - mod = getattr(attn, proj_name) - if not hasattr(mod, "weight"): - continue - if mod.weight.dtype in (torch.uint8, torch.int8): + if hasattr(attn, "wo_a") and hasattr(attn.wo_a, "weight"): + if attn.wo_a.weight.dtype in (torch.uint8, torch.int8): E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16) - self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX) + self._convert_nvfp4_to_fp8(attn.wo_a, E2M1_LUT, FP8_MAX) fp8_converted += 1 - - # Fix input_scale for attention NVFP4 projections - # process_weights_after_loading reads input_scale and computes - # input_global_scale_inv = 1/input_scale. By fixing input_scale - # here, the quant method will propagate the correct value. + + # Install CuTeDSL quant method on attention NVFP4 projections. + # When vLLM calls process_weights_after_loading, CuTeDSLNvfp4Method + # will read the NVFP4 weights, create CuTeDSL runners, and swap + # the quant method to CuTeDSLNvfp4LinearMethod. for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]: if not hasattr(attn, proj_name): continue mod = getattr(attn, proj_name) - if not hasattr(mod, "input_scale"): - continue if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8): continue - - # Temporarily dequantize weight to BF16 for warmup - E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16) - w_uint8 = mod.weight.data - w_bf16_unpacked = self._unpack_nvfp4_to_bf16(w_uint8, E2M1_LUT, w_uint8.device) - if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): - block_scale = self._block_scale_to_float32(mod.weight_scale.data) - if block_scale.dim() == 2 and w_bf16_unpacked.dim() == 2: - block_size = w_bf16_unpacked.shape[1] // block_scale.shape[1] - block_scale_expanded = block_scale.unsqueeze(-1).expand(-1, -1, block_size).reshape(w_bf16_unpacked.shape) - else: - block_scale_expanded = block_scale - global_scale = mod.weight_scale_2.data.max().item() - w_bf16_dequant = (w_bf16_unpacked.float() * block_scale_expanded * global_scale).to(torch.bfloat16) - else: - w_bf16_dequant = w_bf16_unpacked - - # Compute correct input_scale from warmup - with torch.no_grad(): - in_features = w_bf16_dequant.shape[-1] - dummy_input = torch.randn(256, in_features, dtype=torch.bfloat16, device=mod.weight.device) * 2.0 - ref_output = torch.nn.functional.linear(dummy_input, w_bf16_dequant) - act_amax = ref_output.amax().item() - del w_bf16_unpacked, w_bf16_dequant, ref_output - - # input_scale should be 1/(amax * headroom) — this is the - # activation global scale that maps activations to FP4 range. - # process_weights_after_loading computes: - # input_global_scale_inv = input_scale.max() - # input_global_scale = 1 / input_global_scale_inv - headroom = 1.2 - new_input_scale = 1.0 / (act_amax * headroom) if act_amax > 0 else mod.input_scale.data - - if layer_idx == 0: - old_input_scale = mod.input_scale.data.item() if mod.input_scale.data.numel() == 1 else mod.input_scale.data.max().item() - print(f"[CLAWMINE] Layer 0: {proj_name} input_scale: {old_input_scale:.8f} → {new_input_scale:.8f} (act_amax={act_amax:.4f})") - - mod.input_scale = torch.nn.Parameter( - torch.tensor([new_input_scale] * mod.input_scale.data.numel(), dtype=mod.input_scale.data.dtype, device=mod.input_scale.data.device), - requires_grad=False - ) - input_scale_fixes += 1 + is_fused = (proj_name == "fused_wqa_wkv") + mod.quant_method = CuTeDSLNvfp4Method(is_fused=is_fused) + cutedsl_installed += 1 - _shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None - - from tqdm import tqdm - for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (upcast)NVFP4→BF16 attn projs", unit="layer"): - attn = layer.attn - - - # Compressor: still needs BF16 reconstruction + # Compressor: BF16 reconstruction (unchanged) mla_attn = getattr(attn, "mla_attn", None) if mla_attn is not None: E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16) @@ -1794,49 +1745,147 @@ class DeepseekV4Model(nn.Module): if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"): compressor_converted += self._reconstruct_compressor_weight( idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index) - + + # Shared expert: install CuTeDSL shared expert runner + ffn = layer.ffn + if hasattr(ffn, 'shared_experts') and ffn.shared_experts is not None: + swiglu_limit = ffn.swiglu_limit if hasattr(ffn, 'swiglu_limit') else None + se = ffn.shared_experts + if self._install_shared_expert_runner(se, swiglu_limit, layer_idx): + shared_expert_installed += 1 - def _dequant_nvfp4_to_bf16(self, mod, e2m1_lut): - """Dequantize NVFP4 weight to bf16 for normal .forward() path.""" - w_uint8 = mod.weight.data - device = w_uint8.device - w_bf16 = self._unpack_nvfp4_to_bf16(w_uint8, e2m1_lut, device) - - # Dequantize with scales - if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): - block_scale = self._block_scale_to_float32(mod.weight_scale.data) - if block_scale.dim() == 2 and w_bf16.dim() == 2: - block_size = w_bf16.shape[1] // block_scale.shape[1] - block_scale_expanded = block_scale.unsqueeze(-1).expand( - -1, -1, block_size - ).reshape(w_bf16.shape) - else: - block_scale_expanded = block_scale - global_scale = mod.weight_scale_2.data.max().item() - input_scale = ( - mod.input_scale.data.max().item() - if hasattr(mod, "input_scale") - else 1.0 - ) - # NOTE: input_scale is for ACTIVATIONS, not weights. - # Weight dequant = e2m1 * block_scale * global_scale (NO input_scale) - w_dequant = w_bf16.float() * block_scale_expanded * global_scale - w_dequant = w_dequant.to(torch.bfloat16) + def _install_shared_expert_runner(self, se_mlp, swiglu_limit: float | None, layer_idx: int) -> bool: + """Install CuTeDSL shared expert runner on a DeepseekV4MLP. + + Extracts gate_up and down NVFP4 weights, creates + CuTeDSLSharedExpertRunner, and replaces the MLP's forward + with the fused L1+SiLU+L2 pipeline. + """ + from cutedsl.shared_expert_pipeline import CuTeDSLSharedExpertRunner + + gate_up = se_mlp.gate_up_proj + down = se_mlp.down_proj + + # Check that both projections have NVFP4 weights + if not (hasattr(gate_up, "weight") and hasattr(down, "weight")): + return False + if gate_up.weight.dtype not in (torch.uint8, torch.int8): + return False + if down.weight.dtype not in (torch.uint8, torch.int8): + return False + + device = gate_up.weight.device + hidden_size = gate_up.weight.shape[1] * 2 # 2 FP4 per uint8 + intermediate_size_2x = gate_up.weight.shape[0] # gate + up stacked + intermediate_size = intermediate_size_2x // 2 + + # ── L1: gate_up (MergedColumnParallelLinear, gate + up fused) ── + l1_w_uint8 = gate_up.weight.data # (2*intermediate, hidden//2) uint8 + l1_sf = gate_up.weight_scale.data # (2*intermediate, hidden//16) float8 + l1_gs_data = gate_up.weight_scale_2.data # float32 [2] (gate, up) + + # uint8 → float4_e2m1fn_x2, permute to (K_packed, N) + l1_w_fp4 = l1_w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + + # Block scales: (N, K_sf) → (K_sf, N) + if l1_sf.dtype != torch.float8_e4m3fn: + l1_sf = l1_sf.to(torch.float8_e4m3fn) + l1_sf = l1_sf.permute(1, 0).contiguous() + + # Dual global scales: normalize to max, fold ratio into block scales + l1_gs1 = l1_gs_data[0].item() + l1_gs2 = l1_gs_data[1].item() + l1_gs = max(l1_gs1, l1_gs2) + if l1_gs1 != l1_gs2: + l1_sf_f32 = l1_sf.float() + # After permute to (K_sf, N): first intermediate rows are gate, then up + l1_sf_f32[:, :intermediate_size] *= (l1_gs1 / l1_gs) + l1_sf_f32[:, intermediate_size:] *= (l1_gs2 / l1_gs) + l1_sf = l1_sf_f32.to(torch.float8_e4m3fn) + + # ── L2: down (RowParallelLinear, single projection) ── + l2_w_uint8 = down.weight.data # (hidden, intermediate//2) uint8 + l2_sf = down.weight_scale.data # (hidden, intermediate//16) float8 + l2_gs = down.weight_scale_2.data.max().item() # float32 scalar + + # uint8 → float4_e2m1fn_x2, permute to (K_packed, N) + l2_w_fp4 = l2_w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + + # Block scales: (N, K_sf) → (K_sf, N) + if l2_sf.dtype != torch.float8_e4m3fn: + l2_sf = l2_sf.to(torch.float8_e4m3fn) + l2_sf = l2_sf.permute(1, 0).contiguous() + + # Create runner, set weights, finalize + runner = CuTeDSLSharedExpertRunner( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + device=device, + swiglu_limit=swiglu_limit if swiglu_limit is not None else 10.0, + ) + runner.l1_fp4 = [l1_w_fp4] + runner.l1_sf = [l1_sf] + runner.l1_gs = [l1_gs] + runner.l2_fp4 = [l2_w_fp4] + runner.l2_sf = [l2_sf] + runner.l2_gs = [l2_gs] + runner.finalize_weights() + + # Warmup: compute activation global scales + with torch.no_grad(): + sample = torch.randn(min(8, 256), hidden_size, + dtype=torch.bfloat16, device=device) * 2.0 + runner.compute_activation_global_scales(sample) + + # Replace the MLP's forward with the runner + se_mlp._cutedsl_runner = runner + + # Monkey-patch forward to use the CuTeDSL runner + original_cls = type(se_mlp) + + def _cutedsl_forward(self, x): + output = self._cutedsl_runner.run(x) + # Down_proj with reduce_results may need all-reduce handled + # by RowParallelLinear. Since we bypassed it, check if we need + # to all-reduce manually. + if hasattr(self, '_needs_tp_reduce') and self._needs_tp_reduce: + from vllm.distributed import tensor_model_parallel_all_reduce + output = tensor_model_parallel_all_reduce(output) + return output + + import types + se_mlp.forward = types.MethodType(_cutedsl_forward, se_mlp) + + # Check if down_proj needs TP all-reduce + # reduce_results=True means the original RowParallelLinear would all-reduce + if hasattr(down, 'reduce_results') and down.reduce_results and getattr(down, 'tp_size', 1) > 1: + se_mlp._needs_tp_reduce = True else: - w_dequant = w_bf16 - - # Free source tensors eagerly to avoid holding uint8+bf16+fp32 simultaneously - del w_uint8, w_bf16 - mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False) - del w_dequant - from vllm.model_executor.layers.linear import UnquantizedLinearMethod - mod.quant_method = UnquantizedLinearMethod() - for attr in ("weight_scale", "weight_scale_2", "input_scale", - "weight_scale_inv"): - if hasattr(mod, attr): - delattr(mod, attr) + se_mlp._needs_tp_reduce = False + + # Free NVFP4 params from gate_up and down (replace with dummy BF16) + for mod in [gate_up, down]: + out_dim = mod.weight.shape[0] + in_dim = mod.weight.shape[1] * 2 + mod.weight = torch.nn.Parameter( + torch.zeros(out_dim, in_dim, dtype=torch.bfloat16, + device=device), + requires_grad=False, + ) + from vllm.model_executor.layers.linear import UnquantizedLinearMethod + mod.quant_method = UnquantizedLinearMethod() + for attr in ("weight_scale", "weight_scale_2", "input_scale", + "input_global_scale", "input_global_scale_inv", + "weight_global_scale", "alpha", "weight_scale_inv"): + if hasattr(mod, attr): + try: + delattr(mod, attr) + except Exception: + pass + + return True def _convert_nvfp4_to_fp8(self, mod, e2m1_lut, fp8_max): """Convert NVFP4 weight to FP8 for fp8_einsum path (wo_a only). @@ -2407,59 +2456,6 @@ class DeepseekV4ForCausalLM(nn.Module): del residual, fn, hc_scale, hc_base, x, post_mix, comb_mix torch.cuda.empty_cache() - def _post_quant_fix(self) -> None: - """Called by vLLM's process_weights_after_loading AFTER quant methods - have set up their attributes. Dequantizes NVFP4 weights to BF16 for - attention projections and shared experts because - FlashInferCutlassNvFp4LinearKernel uses broken input_global_scale_inv.""" - from vllm.model_executor.layers.linear import UnquantizedLinearMethod - - E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16) - fixed = 0 - for layer_idx, layer in enumerate(self.model.layers): - attn = layer.attn - # Attention projections - for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]: - if not hasattr(attn, proj_name): - continue - mod = getattr(attn, proj_name) - if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8): - continue - self.model._dequant_nvfp4_to_bf16(mod, E2M1_LUT) - mod.quant_method = UnquantizedLinearMethod() - for attr in ("weight_scale", "weight_scale_2", "input_scale", - "input_global_scale", "input_global_scale_inv", - "weight_global_scale", "alpha", "weight_scale_inv"): - if hasattr(mod, attr): - try: delattr(mod, attr) - except: pass - fixed += 1 - - # Shared expert projections (also NVFP4 with broken input_scale) - ffn = layer.ffn - if hasattr(ffn, 'shared_experts'): - for proj_name in ["gate_up_proj", "down_proj"]: - se = ffn.shared_experts - if not hasattr(se, proj_name): - continue - mod = getattr(se, proj_name) - if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8): - continue - self.model._dequant_nvfp4_to_bf16(mod, E2M1_LUT) - mod.quant_method = UnquantizedLinearMethod() - for attr in ("weight_scale", "weight_scale_2", "input_scale", - "input_global_scale", "input_global_scale_inv", - "weight_global_scale", "alpha", "weight_scale_inv"): - if hasattr(mod, attr): - try: delattr(mod, attr) - except: pass - fixed += 1 - print(f" [CLAWMINE] Post-quant fix: {fixed} attention projections → BF16 ✓", flush=True) - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() - def _register_post_quant_fix(self) -> None: - """No-op — we use _post_quant_fix() called from process_weights_after_loading.""" - pass - diff --git a/vllm/patches/utils.py b/vllm/patches/utils.py index 494f7b86..63b0d8c7 100644 --- a/vllm/patches/utils.py +++ b/vllm/patches/utils.py @@ -121,12 +121,6 @@ def process_weights_after_loading( with device_loading_context(module, target_device): module.process_weights_after_loading(model_config.dtype) - # Needed for torchao model reloading via model.reload_weights - # @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights` - # Custom: allow models to run post-quant-init fixes - if hasattr(model, '_post_quant_fix'): - model._post_quant_fix() - if model_config.quantization == "torchao": set_torchao_reload_attrs(model, model_config)