Wire CuTeDSL kernels into vLLM: replace all BF16 dequant with native NVFP4

- CuTeDSLNvfp4Method: custom quant method that creates CuTeDSL runners
  during process_weights_after_loading, then swaps to CuTeDSLNvfp4LinearMethod
  for forward dispatch
- Attention projections (fused_wqa_wkv, wq_b, wo_b) now route through
  CuTeDSLNvfp4Linear (cosine 0.992-0.996 vs BF16 reference)
- Shared expert now uses CuTeDSLSharedExpertRunner (cosine 0.992 vs BF16)
  with monkey-patched forward for fused L1+SiLU+L2 pipeline
- Deleted all BF16 dequant code (_dequant_nvfp4_to_bf16, _post_quant_fix,
  input_scale fixes)
- Deleted _post_quant_fix hook from utils.py
- Fixed SwiGLU clamp: gate clamped BEFORE SiLU (matching SiluAndMulWithClamp)
- Cleaned up all debug prints
- Updated Dockerfile with new kernel files
This commit is contained in:
2026-05-18 20:27:42 +00:00
parent 6ce6a47be9
commit 450793311c
7 changed files with 314 additions and 193 deletions

View File

@@ -35,6 +35,9 @@ ARG VLLM_LOADER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/
COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py
COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py
COPY vllm/nvfp4_cutedsl.py ${VLLM_MODELS_DIR}/nvfp4_cutedsl.py
COPY vllm/cutedsl_quant_method.py ${VLLM_MODELS_DIR}/cutedsl_quant_method.py
COPY cutedsl/nvfp4_linear.py /root/nvfp4-megamoe-kernel/cutedsl/nvfp4_linear.py
COPY cutedsl/shared_expert_pipeline.py /root/nvfp4-megamoe-kernel/cutedsl/shared_expert_pipeline.py
COPY vllm/patches/utils.py ${VLLM_LOADER_DIR}/utils.py
RUN sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \

View File

@@ -62,9 +62,8 @@ class CuTeDSLNvfp4Linear:
self._gsa_buf = None
self._buffers_allocated = False
import os
print(f"[CLAWMINE] Nvfp4Linear init: in={in_features} out={out_features} "
f"max_tokens={max_num_tokens} pid={os.getpid()}", flush=True)
print(f" Nvfp4Linear init: in={in_features} out={out_features} "
f"max_tokens={max_num_tokens}", flush=True)
def finalize_weights(self):
"""Process weights for CuTeDSL GEMM."""

View File

@@ -83,9 +83,8 @@ class CuTeDSLSharedExpertRunner:
self._expert_offsets_buf = None
self._buffers_allocated = False
import os
print(f"[CLAWMINE] SharedExpert init: hidden={hidden_size} intermediate={intermediate_size} "
f"max_tokens={max_num_tokens} pid={os.getpid()}", flush=True)
print(f" SharedExpert init: hidden={hidden_size} intermediate={intermediate_size} "
f"max_tokens={max_num_tokens}", flush=True)
def set_swiglu_limit(self, limit: float):
self.swiglu_limit = limit
@@ -192,11 +191,10 @@ class CuTeDSLSharedExpertRunner:
if l1_out is not None and not torch.isnan(l1_out).any():
gate = l1_out[:, :self.intermediate_size]
up = l1_out[:, self.intermediate_size:]
gate_silu = torch.nn.functional.silu(gate)
if self.swiglu_limit is not None:
gate_silu = gate_silu.clamp(max=self.swiglu_limit)
gate = gate.clamp(max=self.swiglu_limit)
up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
activated = gate_silu * up
activated = torch.nn.functional.silu(gate) * up
_, _, l2_gs = quantize_to_nvfp4(activated)
self._l2_activation_global_scale = l2_gs
@@ -288,11 +286,11 @@ class CuTeDSLSharedExpertRunner:
gate = l1_out[:, :self.intermediate_size]
up = l1_out[:, self.intermediate_size:]
gate_silu = torch.nn.functional.silu(gate)
if self.swiglu_limit is not None:
gate_silu = gate_silu.clamp(max=self.swiglu_limit)
# Match SiluAndMulWithClamp: clamp gate BEFORE silu, clamp up to [-limit, limit]
gate = gate.clamp(max=self.swiglu_limit)
up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
intermediate = gate_silu * up
intermediate = torch.nn.functional.silu(gate) * up
output = self._run_l2(intermediate)
return output

View File

@@ -0,0 +1,135 @@
"""CuTeDSL NVFP4 Quantization Method for vLLM
Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM.
After process_weights_after_loading, the module's quant_method is swapped
to CuTeDSLNvfp4LinearMethod which routes forward() through CuTeDSL.
"""
import torch
from vllm.model_executor.layers.linear import LinearMethodBase
class CuTeDSLNvfp4Method(LinearMethodBase):
    """Pre-processing quant method that sets up CuTeDSL runners.

    Installed on NVFP4 linear layers before process_weights_after_loading.
    When vLLM calls process_weights_after_loading, this method:

    1. Reads NVFP4 weights (uint8, float8 block scales, float32 global scales)
    2. Converts to CuTeDSL format
    3. Creates CuTeDSLNvfp4Linear runner
    4. Stores runner on the module (as ``layer._cutedsl_runner``)
    5. Frees original weight/scale params
    6. Replaces quant_method with CuTeDSLNvfp4LinearMethod
    """

    # Number of random BF16 rows fed to the runner to calibrate its
    # activation global scale during warmup.
    _WARMUP_ROWS = 8

    # Quantization attributes that are removed from the layer once the
    # CuTeDSL runner owns the converted weights.
    _STALE_ATTRS = (
        "weight_scale", "weight_scale_2", "input_scale",
        "input_global_scale", "input_global_scale_inv",
        "weight_global_scale", "alpha", "weight_scale_inv",
    )

    def __init__(self, is_fused: bool = False):
        """
        Args:
            is_fused: True for MergedColumnParallelLinear with two sub-projections
                (e.g., fused_wqa_wkv with q_a + kv, or gate_up_proj with gate + up).
                Handles dual weight_scale_2 the same way as MoE L1:
                normalize to max(gs1, gs2), fold ratio into block scales.
        """
        self.is_fused = is_fused

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """Intentionally a no-op.

        We don't create weights — ModelOptNvFp4LinearMethod already did that.
        This method is only installed after weight loading.
        """
        return None

    def _fold_global_scales(self, layer, sf, out_features):
        """Return ``(block_scales, global_scale)`` for the CuTeDSL runner.

        Single-scale layers use ``weight_scale_2.max()`` and keep ``sf``
        unchanged. Fused layers carrying exactly two global scales are
        normalized to the larger one; each sub-projection's block scales are
        multiplied by ``gs_i / max(gs1, gs2)`` so one global scale serves both.

        Args:
            layer: Module holding ``weight_scale_2`` (and optionally
                ``logical_widths``).
            sf: Block scales already permuted to (K_sf, N), float8_e4m3fn.
            out_features: Total output width N of the (possibly fused) layer.
        """
        weight_scale_2 = layer.weight_scale_2.data
        if not (self.is_fused and weight_scale_2.numel() == 2):
            return sf, weight_scale_2.max().item()

        # Dual global scales (fused_wqa_wkv: q_a + kv, gate_up: gate + up).
        # One host transfer for both values; flatten() tolerates any
        # two-element shape.
        gs1, gs2 = weight_scale_2.flatten().tolist()
        gs = max(gs1, gs2)
        if gs1 == gs2:
            return sf, gs

        # After permute to (K_sf, N): first sub-projection's output
        # columns, then second sub-projection's output columns.
        logical_widths = getattr(layer, "logical_widths", None)
        if logical_widths is not None and len(logical_widths) == 2:
            split_point = logical_widths[0]
        else:
            # No width metadata: fall back to an even split.
            split_point = out_features // 2

        # Fold ratio into block scales via float32 round-trip.
        sf_f32 = sf.float()
        sf_f32[:, :split_point] *= (gs1 / gs)
        sf_f32[:, split_point:] *= (gs2 / gs)
        return sf_f32.to(torch.float8_e4m3fn), gs

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert ``layer``'s NVFP4 parameters into a CuTeDSL runner.

        Side effects on ``layer``: sets ``_cutedsl_runner``, replaces
        ``weight`` with a zero BF16 placeholder, deletes the NVFP4 scale
        attributes, and swaps ``quant_method`` to CuTeDSLNvfp4LinearMethod.
        """
        from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

        w_uint8 = layer.weight.data  # (out, in//2) uint8 packed E2M1
        device = w_uint8.device
        out_features = w_uint8.shape[0]
        in_features = w_uint8.shape[1] * 2  # 2 FP4 values per uint8

        # Convert uint8 → float4_e2m1fn_x2, then permute to (K_packed, N)
        w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

        # Block scales: (N, K_sf) → (K_sf, N)
        sf = layer.weight_scale.data
        if sf.dtype != torch.float8_e4m3fn:
            sf = sf.to(torch.float8_e4m3fn)
        sf = sf.permute(1, 0).contiguous()

        # Global scale (folding dual scales into the block scales if fused)
        sf, gs = self._fold_global_scales(layer, sf, out_features)

        # Create CuTeDSL runner
        runner = CuTeDSLNvfp4Linear(
            in_features=in_features,
            out_features=out_features,
            device=device,
        )
        runner.fp4 = [w_fp4]
        runner.sf = [sf]
        runner.gs = [gs]
        runner.finalize_weights()

        # Store runner on the module
        layer._cutedsl_runner = runner

        # Warmup: compute activation global scale from sample data
        with torch.no_grad():
            sample = torch.randn(self._WARMUP_ROWS, in_features,
                                 dtype=torch.bfloat16, device=device) * 2.0
            runner.compute_activation_global_scale(sample)

        # Replace weight with dummy BF16 (needed by vLLM module introspection).
        # NOTE(review): this placeholder is full-size (out x in BF16), i.e.
        # larger than the packed FP4 tensor it replaces — confirm that
        # introspection really needs the full shape.
        layer.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=torch.bfloat16,
                        device=device),
            requires_grad=False,
        )

        # Free original NVFP4 params. Removal stays best-effort, but only the
        # "attribute cannot be deleted" case is tolerated so that unrelated
        # failures are no longer hidden.
        for attr in self._STALE_ATTRS:
            if hasattr(layer, attr):
                try:
                    delattr(layer, attr)
                except AttributeError:
                    pass

        # Swap quant method to the forward-only one
        layer.quant_method = CuTeDSLNvfp4LinearMethod()

    def apply(self, layer, x, bias=None):
        """Always raises: forward must go through CuTeDSLNvfp4LinearMethod.

        Raises:
            NotImplementedError: if called before
                process_weights_after_loading swapped the quant method.
        """
        raise NotImplementedError(
            "CuTeDSLNvfp4Method should be replaced by "
            "CuTeDSLNvfp4LinearMethod after process_weights_after_loading"
        )
class CuTeDSLNvfp4LinearMethod(LinearMethodBase):
    """Forward path: BF16 input → CuTeDSL NVFP4 GEMM → BF16 output.

    Expects the layer to carry a callable ``_cutedsl_runner`` attribute;
    all GEMM work is delegated to it.
    """

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """No-op: this method is installed on layers whose weights already exist.

        Defined explicitly so the class provides every method of the linear
        quant-method interface and can be instantiated.
        """
        return None

    def apply(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor:
        """Run the layer's CuTeDSL runner on ``x`` and add ``bias`` if given.

        Args:
            layer: Module holding ``_cutedsl_runner``.
            x: Input activations passed straight to the runner.
            bias: Optional bias added to the runner output; previously this
                argument was silently ignored.

        Returns:
            The runner output, plus ``bias`` when one is supplied.
        """
        output = layer._cutedsl_runner(x)
        if bias is not None:
            output = output + bias
        return output

View File

@@ -87,10 +87,6 @@ class CuTeDSLMoERunner:
self._max_chunks_per_expert = cutedsl_ceil_div(
self.max_num_tokens * self.top_k, self.num_experts * 128
)
import os
print(f"[CLAWMINE] Runner init: max_num_tokens={self.max_num_tokens} top_k={self.top_k} "
f"num_experts={self.num_experts} max_chunks={self._max_chunks_per_expert} "
f"pid={os.getpid()}", flush=True)
self._buffers_allocated = False
def set_swiglu_limit(self, limit: float | None):

View File

@@ -1683,104 +1683,55 @@ class DeepseekV4Model(nn.Module):
layer.ffn.finalize_mega_moe_weights()
def _convert_nvfp4_post_load(self):
"""Post-load conversion of NVFP4 weights for vLLM compatibility.
Fixes the attention input_scale values BEFORE
process_weights_after_loading runs. The checkpoint input_scale
values are wrong and cause NaN during activation quantization.
We compute correct values by dequantizing to BF16 temporarily
and running a warmup forward.
wo_a is converted to FP8 for fp8_einsum (no input_scale needed).
Compressor weights are reconstructed from checkpoint sub-weights.
"""Post-load setup of CuTeDSL NVFP4 runners for attention and shared experts.
Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM.
For attention projections (fused_wqa_wkv, wq_b, wo_b), installs
CuTeDSLNvfp4Method which creates CuTeDSL runners during
process_weights_after_loading.
For shared experts, creates CuTeDSLSharedExpertRunner which handles
the full L1 (gate_up) + SiLU + L2 (down) pipeline.
wo_a is converted to FP8 for fp8_einsum (unchanged).
Compressor weights are reconstructed from checkpoint sub-weights (unchanged).
"""
fp8_proj_names = {"wo_a"}
from vllm.model_executor.models.cutedsl_quant_method import CuTeDSLNvfp4Method
fp8_converted = 0
compressor_converted = 0
input_scale_fixes = 0
cutedsl_installed = 0
shared_expert_installed = 0
_shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None
from tqdm import tqdm
for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (fix)NVFP4 attn input_scale", unit="layer"):
for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" NVFP4→CuTeDSL setup", unit="layer"):
attn = layer.attn
# FP8 conversion: wo_a (used by fp8_einsum, no input_scale)
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
for proj_name in fp8_proj_names:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight"):
continue
if mod.weight.dtype in (torch.uint8, torch.int8):
if hasattr(attn, "wo_a") and hasattr(attn.wo_a, "weight"):
if attn.wo_a.weight.dtype in (torch.uint8, torch.int8):
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX)
self._convert_nvfp4_to_fp8(attn.wo_a, E2M1_LUT, FP8_MAX)
fp8_converted += 1
# Fix input_scale for attention NVFP4 projections
# process_weights_after_loading reads input_scale and computes
# input_global_scale_inv = 1/input_scale. By fixing input_scale
# here, the quant method will propagate the correct value.
# Install CuTeDSL quant method on attention NVFP4 projections.
# When vLLM calls process_weights_after_loading, CuTeDSLNvfp4Method
# will read the NVFP4 weights, create CuTeDSL runners, and swap
# the quant method to CuTeDSLNvfp4LinearMethod.
for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "input_scale"):
continue
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
# Temporarily dequantize weight to BF16 for warmup
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
w_uint8 = mod.weight.data
w_bf16_unpacked = self._unpack_nvfp4_to_bf16(w_uint8, E2M1_LUT, w_uint8.device)
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
block_scale = self._block_scale_to_float32(mod.weight_scale.data)
if block_scale.dim() == 2 and w_bf16_unpacked.dim() == 2:
block_size = w_bf16_unpacked.shape[1] // block_scale.shape[1]
block_scale_expanded = block_scale.unsqueeze(-1).expand(-1, -1, block_size).reshape(w_bf16_unpacked.shape)
else:
block_scale_expanded = block_scale
global_scale = mod.weight_scale_2.data.max().item()
w_bf16_dequant = (w_bf16_unpacked.float() * block_scale_expanded * global_scale).to(torch.bfloat16)
else:
w_bf16_dequant = w_bf16_unpacked
# Compute correct input_scale from warmup
with torch.no_grad():
in_features = w_bf16_dequant.shape[-1]
dummy_input = torch.randn(256, in_features, dtype=torch.bfloat16, device=mod.weight.device) * 2.0
ref_output = torch.nn.functional.linear(dummy_input, w_bf16_dequant)
act_amax = ref_output.amax().item()
del w_bf16_unpacked, w_bf16_dequant, ref_output
# input_scale should be 1/(amax * headroom) — this is the
# activation global scale that maps activations to FP4 range.
# process_weights_after_loading computes:
# input_global_scale_inv = input_scale.max()
# input_global_scale = 1 / input_global_scale_inv
headroom = 1.2
new_input_scale = 1.0 / (act_amax * headroom) if act_amax > 0 else mod.input_scale.data
if layer_idx == 0:
old_input_scale = mod.input_scale.data.item() if mod.input_scale.data.numel() == 1 else mod.input_scale.data.max().item()
print(f"[CLAWMINE] Layer 0: {proj_name} input_scale: {old_input_scale:.8f}{new_input_scale:.8f} (act_amax={act_amax:.4f})")
mod.input_scale = torch.nn.Parameter(
torch.tensor([new_input_scale] * mod.input_scale.data.numel(), dtype=mod.input_scale.data.dtype, device=mod.input_scale.data.device),
requires_grad=False
)
input_scale_fixes += 1
is_fused = (proj_name == "fused_wqa_wkv")
mod.quant_method = CuTeDSLNvfp4Method(is_fused=is_fused)
cutedsl_installed += 1
_shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None
from tqdm import tqdm
for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (upcast)NVFP4→BF16 attn projs", unit="layer"):
attn = layer.attn
# Compressor: still needs BF16 reconstruction
# Compressor: BF16 reconstruction (unchanged)
mla_attn = getattr(attn, "mla_attn", None)
if mla_attn is not None:
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
@@ -1794,49 +1745,147 @@ class DeepseekV4Model(nn.Module):
if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"):
compressor_converted += self._reconstruct_compressor_weight(
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index)
# Shared expert: install CuTeDSL shared expert runner
ffn = layer.ffn
if hasattr(ffn, 'shared_experts') and ffn.shared_experts is not None:
swiglu_limit = ffn.swiglu_limit if hasattr(ffn, 'swiglu_limit') else None
se = ffn.shared_experts
if self._install_shared_expert_runner(se, swiglu_limit, layer_idx):
shared_expert_installed += 1
def _dequant_nvfp4_to_bf16(self, mod, e2m1_lut):
"""Dequantize NVFP4 weight to bf16 for normal .forward() path."""
w_uint8 = mod.weight.data
device = w_uint8.device
w_bf16 = self._unpack_nvfp4_to_bf16(w_uint8, e2m1_lut, device)
# Dequantize with scales
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
block_scale = self._block_scale_to_float32(mod.weight_scale.data)
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
block_scale_expanded = block_scale.unsqueeze(-1).expand(
-1, -1, block_size
).reshape(w_bf16.shape)
else:
block_scale_expanded = block_scale
global_scale = mod.weight_scale_2.data.max().item()
input_scale = (
mod.input_scale.data.max().item()
if hasattr(mod, "input_scale")
else 1.0
)
# NOTE: input_scale is for ACTIVATIONS, not weights.
# Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
w_dequant = w_bf16.float() * block_scale_expanded * global_scale
w_dequant = w_dequant.to(torch.bfloat16)
def _install_shared_expert_runner(self, se_mlp, swiglu_limit: float | None, layer_idx: int) -> bool:
"""Install CuTeDSL shared expert runner on a DeepseekV4MLP.
Extracts gate_up and down NVFP4 weights, creates
CuTeDSLSharedExpertRunner, and replaces the MLP's forward
with the fused L1+SiLU+L2 pipeline.
"""
from cutedsl.shared_expert_pipeline import CuTeDSLSharedExpertRunner
gate_up = se_mlp.gate_up_proj
down = se_mlp.down_proj
# Check that both projections have NVFP4 weights
if not (hasattr(gate_up, "weight") and hasattr(down, "weight")):
return False
if gate_up.weight.dtype not in (torch.uint8, torch.int8):
return False
if down.weight.dtype not in (torch.uint8, torch.int8):
return False
device = gate_up.weight.device
hidden_size = gate_up.weight.shape[1] * 2 # 2 FP4 per uint8
intermediate_size_2x = gate_up.weight.shape[0] # gate + up stacked
intermediate_size = intermediate_size_2x // 2
# ── L1: gate_up (MergedColumnParallelLinear, gate + up fused) ──
l1_w_uint8 = gate_up.weight.data # (2*intermediate, hidden//2) uint8
l1_sf = gate_up.weight_scale.data # (2*intermediate, hidden//16) float8
l1_gs_data = gate_up.weight_scale_2.data # float32 [2] (gate, up)
# uint8 → float4_e2m1fn_x2, permute to (K_packed, N)
l1_w_fp4 = l1_w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# Block scales: (N, K_sf) → (K_sf, N)
if l1_sf.dtype != torch.float8_e4m3fn:
l1_sf = l1_sf.to(torch.float8_e4m3fn)
l1_sf = l1_sf.permute(1, 0).contiguous()
# Dual global scales: normalize to max, fold ratio into block scales
l1_gs1 = l1_gs_data[0].item()
l1_gs2 = l1_gs_data[1].item()
l1_gs = max(l1_gs1, l1_gs2)
if l1_gs1 != l1_gs2:
l1_sf_f32 = l1_sf.float()
# After permute to (K_sf, N): first intermediate rows are gate, then up
l1_sf_f32[:, :intermediate_size] *= (l1_gs1 / l1_gs)
l1_sf_f32[:, intermediate_size:] *= (l1_gs2 / l1_gs)
l1_sf = l1_sf_f32.to(torch.float8_e4m3fn)
# ── L2: down (RowParallelLinear, single projection) ──
l2_w_uint8 = down.weight.data # (hidden, intermediate//2) uint8
l2_sf = down.weight_scale.data # (hidden, intermediate//16) float8
l2_gs = down.weight_scale_2.data.max().item() # float32 scalar
# uint8 → float4_e2m1fn_x2, permute to (K_packed, N)
l2_w_fp4 = l2_w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# Block scales: (N, K_sf) → (K_sf, N)
if l2_sf.dtype != torch.float8_e4m3fn:
l2_sf = l2_sf.to(torch.float8_e4m3fn)
l2_sf = l2_sf.permute(1, 0).contiguous()
# Create runner, set weights, finalize
runner = CuTeDSLSharedExpertRunner(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
device=device,
swiglu_limit=swiglu_limit if swiglu_limit is not None else 10.0,
)
runner.l1_fp4 = [l1_w_fp4]
runner.l1_sf = [l1_sf]
runner.l1_gs = [l1_gs]
runner.l2_fp4 = [l2_w_fp4]
runner.l2_sf = [l2_sf]
runner.l2_gs = [l2_gs]
runner.finalize_weights()
# Warmup: compute activation global scales
with torch.no_grad():
sample = torch.randn(min(8, 256), hidden_size,
dtype=torch.bfloat16, device=device) * 2.0
runner.compute_activation_global_scales(sample)
# Replace the MLP's forward with the runner
se_mlp._cutedsl_runner = runner
# Monkey-patch forward to use the CuTeDSL runner
original_cls = type(se_mlp)
def _cutedsl_forward(self, x):
output = self._cutedsl_runner.run(x)
# Down_proj with reduce_results may need all-reduce handled
# by RowParallelLinear. Since we bypassed it, check if we need
# to all-reduce manually.
if hasattr(self, '_needs_tp_reduce') and self._needs_tp_reduce:
from vllm.distributed import tensor_model_parallel_all_reduce
output = tensor_model_parallel_all_reduce(output)
return output
import types
se_mlp.forward = types.MethodType(_cutedsl_forward, se_mlp)
# Check if down_proj needs TP all-reduce
# reduce_results=True means the original RowParallelLinear would all-reduce
if hasattr(down, 'reduce_results') and down.reduce_results and getattr(down, 'tp_size', 1) > 1:
se_mlp._needs_tp_reduce = True
else:
w_dequant = w_bf16
# Free source tensors eagerly to avoid holding uint8+bf16+fp32 simultaneously
del w_uint8, w_bf16
mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False)
del w_dequant
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale",
"weight_scale_inv"):
if hasattr(mod, attr):
delattr(mod, attr)
se_mlp._needs_tp_reduce = False
# Free NVFP4 params from gate_up and down (replace with dummy BF16)
for mod in [gate_up, down]:
out_dim = mod.weight.shape[0]
in_dim = mod.weight.shape[1] * 2
mod.weight = torch.nn.Parameter(
torch.zeros(out_dim, in_dim, dtype=torch.bfloat16,
device=device),
requires_grad=False,
)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale",
"input_global_scale", "input_global_scale_inv",
"weight_global_scale", "alpha", "weight_scale_inv"):
if hasattr(mod, attr):
try:
delattr(mod, attr)
except Exception:
pass
return True
def _convert_nvfp4_to_fp8(self, mod, e2m1_lut, fp8_max):
"""Convert NVFP4 weight to FP8 for fp8_einsum path (wo_a only).
@@ -2407,59 +2456,6 @@ class DeepseekV4ForCausalLM(nn.Module):
del residual, fn, hc_scale, hc_base, x, post_mix, comb_mix
torch.cuda.empty_cache()
def _post_quant_fix(self) -> None:
"""Called by vLLM's process_weights_after_loading AFTER quant methods
have set up their attributes. Dequantizes NVFP4 weights to BF16 for
attention projections and shared experts because
FlashInferCutlassNvFp4LinearKernel uses broken input_global_scale_inv."""
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
fixed = 0
for layer_idx, layer in enumerate(self.model.layers):
attn = layer.attn
# Attention projections
for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
self.model._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale",
"input_global_scale", "input_global_scale_inv",
"weight_global_scale", "alpha", "weight_scale_inv"):
if hasattr(mod, attr):
try: delattr(mod, attr)
except: pass
fixed += 1
# Shared expert projections (also NVFP4 with broken input_scale)
ffn = layer.ffn
if hasattr(ffn, 'shared_experts'):
for proj_name in ["gate_up_proj", "down_proj"]:
se = ffn.shared_experts
if not hasattr(se, proj_name):
continue
mod = getattr(se, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
self.model._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale",
"input_global_scale", "input_global_scale_inv",
"weight_global_scale", "alpha", "weight_scale_inv"):
if hasattr(mod, attr):
try: delattr(mod, attr)
except: pass
fixed += 1
print(f" [CLAWMINE] Post-quant fix: {fixed} attention projections → BF16 ✓", flush=True)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
def _register_post_quant_fix(self) -> None:
"""No-op — we use _post_quant_fix() called from process_weights_after_loading."""
pass

View File

@@ -121,12 +121,6 @@ def process_weights_after_loading(
with device_loading_context(module, target_device):
module.process_weights_after_loading(model_config.dtype)
# Needed for torchao model reloading via model.reload_weights
# @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
# Custom: allow models to run post-quant-init fixes
if hasattr(model, '_post_quant_fix'):
model._post_quant_fix()
if model_config.quantization == "torchao":
set_torchao_reload_attrs(model, model_config)