Wire CuTeDSL kernels into vLLM: replace all BF16 dequant with native NVFP4
- CuTeDSLNvfp4Method: custom quant method that creates CuTeDSL runners during process_weights_after_loading, then swaps to CuTeDSLNvfp4LinearMethod for forward dispatch
- Attention projections (fused_wqa_wkv, wq_b, wo_b) now route through CuTeDSLNvfp4Linear (cosine 0.992-0.996 vs BF16 reference)
- Shared expert now uses CuTeDSLSharedExpertRunner (cosine 0.992 vs BF16) with monkey-patched forward for fused L1+SiLU+L2 pipeline
- Deleted all BF16 dequant code (_dequant_nvfp4_to_bf16, _post_quant_fix, input_scale fixes)
- Deleted _post_quant_fix hook from utils.py
- Fixed SwiGLU clamp: gate clamped BEFORE SiLU (matching SiluAndMulWithClamp)
- Cleaned up all debug prints
- Updated Dockerfile with new kernel files
This commit is contained in:
@@ -35,6 +35,9 @@ ARG VLLM_LOADER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/
|
||||
COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py
|
||||
COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py
|
||||
COPY vllm/nvfp4_cutedsl.py ${VLLM_MODELS_DIR}/nvfp4_cutedsl.py
|
||||
COPY vllm/cutedsl_quant_method.py ${VLLM_MODELS_DIR}/cutedsl_quant_method.py
|
||||
COPY cutedsl/nvfp4_linear.py /root/nvfp4-megamoe-kernel/cutedsl/nvfp4_linear.py
|
||||
COPY cutedsl/shared_expert_pipeline.py /root/nvfp4-megamoe-kernel/cutedsl/shared_expert_pipeline.py
|
||||
COPY vllm/patches/utils.py ${VLLM_LOADER_DIR}/utils.py
|
||||
|
||||
RUN sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \
|
||||
|
||||
@@ -62,9 +62,8 @@ class CuTeDSLNvfp4Linear:
|
||||
self._gsa_buf = None
|
||||
self._buffers_allocated = False
|
||||
|
||||
import os
|
||||
print(f"[CLAWMINE] Nvfp4Linear init: in={in_features} out={out_features} "
|
||||
f"max_tokens={max_num_tokens} pid={os.getpid()}", flush=True)
|
||||
print(f" Nvfp4Linear init: in={in_features} out={out_features} "
|
||||
f"max_tokens={max_num_tokens}", flush=True)
|
||||
|
||||
def finalize_weights(self):
|
||||
"""Process weights for CuTeDSL GEMM."""
|
||||
|
||||
@@ -83,9 +83,8 @@ class CuTeDSLSharedExpertRunner:
|
||||
self._expert_offsets_buf = None
|
||||
self._buffers_allocated = False
|
||||
|
||||
import os
|
||||
print(f"[CLAWMINE] SharedExpert init: hidden={hidden_size} intermediate={intermediate_size} "
|
||||
f"max_tokens={max_num_tokens} pid={os.getpid()}", flush=True)
|
||||
print(f" SharedExpert init: hidden={hidden_size} intermediate={intermediate_size} "
|
||||
f"max_tokens={max_num_tokens}", flush=True)
|
||||
|
||||
def set_swiglu_limit(self, limit: float):
|
||||
self.swiglu_limit = limit
|
||||
@@ -192,11 +191,10 @@ class CuTeDSLSharedExpertRunner:
|
||||
if l1_out is not None and not torch.isnan(l1_out).any():
|
||||
gate = l1_out[:, :self.intermediate_size]
|
||||
up = l1_out[:, self.intermediate_size:]
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
if self.swiglu_limit is not None:
|
||||
gate_silu = gate_silu.clamp(max=self.swiglu_limit)
|
||||
gate = gate.clamp(max=self.swiglu_limit)
|
||||
up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
|
||||
activated = gate_silu * up
|
||||
activated = torch.nn.functional.silu(gate) * up
|
||||
_, _, l2_gs = quantize_to_nvfp4(activated)
|
||||
self._l2_activation_global_scale = l2_gs
|
||||
|
||||
@@ -288,11 +286,11 @@ class CuTeDSLSharedExpertRunner:
|
||||
|
||||
gate = l1_out[:, :self.intermediate_size]
|
||||
up = l1_out[:, self.intermediate_size:]
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
if self.swiglu_limit is not None:
|
||||
gate_silu = gate_silu.clamp(max=self.swiglu_limit)
|
||||
# Match SiluAndMulWithClamp: clamp gate BEFORE silu, clamp up to [-limit, limit]
|
||||
gate = gate.clamp(max=self.swiglu_limit)
|
||||
up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
|
||||
intermediate = gate_silu * up
|
||||
intermediate = torch.nn.functional.silu(gate) * up
|
||||
|
||||
output = self._run_l2(intermediate)
|
||||
return output
|
||||
|
||||
135
vllm/cutedsl_quant_method.py
Normal file
135
vllm/cutedsl_quant_method.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""CuTeDSL NVFP4 Quantization Method for vLLM
|
||||
|
||||
Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM.
|
||||
After process_weights_after_loading, the module's quant_method is swapped
|
||||
to CuTeDSLNvfp4LinearMethod which routes forward() through CuTeDSL.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
|
||||
|
||||
class CuTeDSLNvfp4Method(LinearMethodBase):
    """Pre-processing quant method that sets up CuTeDSL runners.

    Installed on NVFP4 linear layers before process_weights_after_loading.
    When vLLM calls process_weights_after_loading, this method:
      1. Reads NVFP4 weights (uint8, float8 block scales, float32 global scales)
      2. Converts to CuTeDSL format
      3. Creates CuTeDSLNvfp4Linear runner
      4. Stores runner on the module
      5. Frees original weight/scale params
      6. Replaces quant_method with CuTeDSLNvfp4LinearMethod
    """

    # Number of random rows fed to the runner when calibrating the activation
    # global scale (the original expression ``min(8, 256)`` always yielded 8).
    _WARMUP_ROWS = 8

    # Quantization attributes removed from the layer once the runner owns
    # its own copies of the weight and scale tensors.
    _STALE_QUANT_ATTRS = (
        "weight_scale", "weight_scale_2", "input_scale",
        "input_global_scale", "input_global_scale_inv",
        "weight_global_scale", "alpha", "weight_scale_inv",
    )

    def __init__(self, is_fused: bool = False):
        """
        Args:
            is_fused: True for MergedColumnParallelLinear with two sub-projections
                (e.g., fused_wqa_wkv with q_a + kv, or gate_up_proj with gate + up).
                Handles dual weight_scale_2 the same way as MoE L1:
                normalize to max(gs1, gs2), fold ratio into block scales.
        """
        self.is_fused = is_fused

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """Intentionally a no-op.

        We don't create weights — ModelOptNvFp4LinearMethod already did that.
        This method is only installed after weight loading.
        """

    def _resolve_global_scale(self, layer, weight_scale_2, sf, out_features):
        """Return ``(sf, gs)``: block scales and the single global scale.

        For a fused layer carrying exactly two global scales, the larger one
        becomes ``gs`` and the ratio ``gs_i / gs`` is folded into the block
        scales of the matching output columns. Otherwise ``sf`` is returned
        untouched together with ``weight_scale_2.max()``.

        Args:
            layer: module being processed; ``logical_widths`` is read from it
                when present to locate the boundary between sub-projections.
            weight_scale_2: tensor of global weight scale(s).
            sf: block scales already permuted to (K_sf, N).
            out_features: N, used for the half/half fallback split.
        """
        if not (self.is_fused and weight_scale_2.numel() == 2):
            return sf, weight_scale_2.max().item()

        # Dual global scales (fused_wqa_wkv: q_a + kv, gate_up: gate + up)
        gs1 = weight_scale_2[0].item()
        gs2 = weight_scale_2[1].item()
        gs = max(gs1, gs2)
        if gs1 == gs2:
            # Nothing to fold: both halves already share one scale.
            return sf, gs

        # After permute to (K_sf, N): first sub-projection's output columns,
        # then second sub-projection's output columns.
        logical_widths = getattr(layer, 'logical_widths', None)
        if logical_widths is not None and len(logical_widths) == 2:
            split_point = logical_widths[0]
        else:
            split_point = out_features // 2

        # Fold ratio into block scales via float32 round-trip
        sf_f32 = sf.float()
        sf_f32[:, :split_point] *= (gs1 / gs)
        sf_f32[:, split_point:] *= (gs2 / gs)
        return sf_f32.to(torch.float8_e4m3fn), gs

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Build the CuTeDSL runner for ``layer`` and hand forward() over to it.

        Reads ``layer.weight``, ``layer.weight_scale`` and
        ``layer.weight_scale_2``, stores the runner as
        ``layer._cutedsl_runner``, replaces ``layer.weight`` with a BF16
        placeholder, deletes the leftover quantization attributes and sets
        ``layer.quant_method`` to ``CuTeDSLNvfp4LinearMethod``.
        """
        from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

        w_uint8 = layer.weight.data  # (out, in//2) uint8 packed E2M1
        device = w_uint8.device
        out_features = w_uint8.shape[0]
        in_features = w_uint8.shape[1] * 2  # 2 FP4 values per uint8

        # Convert uint8 → float4_e2m1fn_x2, then permute to (K_packed, N)
        w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

        # Block scales: (N, K_sf) → (K_sf, N)
        sf = layer.weight_scale.data
        if sf.dtype != torch.float8_e4m3fn:
            sf = sf.to(torch.float8_e4m3fn)
        sf = sf.permute(1, 0).contiguous()

        # Global scale (possibly folding a second one into the block scales)
        sf, gs = self._resolve_global_scale(
            layer, layer.weight_scale_2.data, sf, out_features)

        # Create CuTeDSL runner
        runner = CuTeDSLNvfp4Linear(
            in_features=in_features,
            out_features=out_features,
            device=device,
        )
        runner.fp4 = [w_fp4]
        runner.sf = [sf]
        runner.gs = [gs]
        runner.finalize_weights()

        # Store runner on the module
        layer._cutedsl_runner = runner

        # Warmup: compute activation global scale from sample data
        with torch.no_grad():
            sample = torch.randn(self._WARMUP_ROWS, in_features,
                                 dtype=torch.bfloat16, device=device) * 2.0
            runner.compute_activation_global_scale(sample)

        # Replace weight with dummy BF16 (needed by vLLM module introspection)
        # NOTE(review): this placeholder holds out*in BF16 values, i.e. 4x the
        # bytes of the packed FP4 weight it replaces — confirm that a
        # full-size tensor is really required by the introspection code.
        layer.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=torch.bfloat16,
                        device=device),
            requires_grad=False,
        )

        # Free original NVFP4 params
        for attr in self._STALE_QUANT_ATTRS:
            if hasattr(layer, attr):
                try:
                    delattr(layer, attr)
                except AttributeError:
                    # hasattr() can be true for names that are not deletable
                    # instance state (e.g. class-level attributes); cleanup is
                    # best-effort, so those are left in place.
                    pass

        # Swap quant method to the forward-only one
        layer.quant_method = CuTeDSLNvfp4LinearMethod()

    def apply(self, layer, x, bias=None):
        """Always raises: forward dispatch belongs to CuTeDSLNvfp4LinearMethod.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError(
            "CuTeDSLNvfp4Method should be replaced by "
            "CuTeDSLNvfp4LinearMethod after process_weights_after_loading"
        )
|
||||
|
||||
|
||||
class CuTeDSLNvfp4LinearMethod(LinearMethodBase):
    """Forward path: BF16 input → CuTeDSL NVFP4 GEMM → BF16 output."""

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """Intentionally a no-op: this method is installed after loading.

        NOTE(review): upstream vLLM declares ``create_weights`` abstract on
        the quant-method base class, in which case a subclass without this
        override cannot be instantiated — verify against the pinned vLLM
        version. The override is harmless either way.
        """

    def apply(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor:
        """Run ``x`` through the runner stored on ``layer``.

        Args:
            layer: module carrying ``_cutedsl_runner``.
            x: input activations passed straight to the runner.
            bias: optional bias added to the runner output; previously it
                was accepted but silently ignored.

        Returns:
            The runner output, plus ``bias`` when one is given.
        """
        output = layer._cutedsl_runner(x)
        if bias is not None:
            output = output + bias
        return output
|
||||
@@ -87,10 +87,6 @@ class CuTeDSLMoERunner:
|
||||
self._max_chunks_per_expert = cutedsl_ceil_div(
|
||||
self.max_num_tokens * self.top_k, self.num_experts * 128
|
||||
)
|
||||
import os
|
||||
print(f"[CLAWMINE] Runner init: max_num_tokens={self.max_num_tokens} top_k={self.top_k} "
|
||||
f"num_experts={self.num_experts} max_chunks={self._max_chunks_per_expert} "
|
||||
f"pid={os.getpid()}", flush=True)
|
||||
self._buffers_allocated = False
|
||||
|
||||
def set_swiglu_limit(self, limit: float | None):
|
||||
|
||||
@@ -1683,104 +1683,55 @@ class DeepseekV4Model(nn.Module):
|
||||
layer.ffn.finalize_mega_moe_weights()
|
||||
|
||||
def _convert_nvfp4_post_load(self):
|
||||
"""Post-load conversion of NVFP4 weights for vLLM compatibility.
|
||||
|
||||
Fixes the attention input_scale values BEFORE
|
||||
process_weights_after_loading runs. The checkpoint input_scale
|
||||
values are wrong and cause NaN during activation quantization.
|
||||
We compute correct values by dequantizing to BF16 temporarily
|
||||
and running a warmup forward.
|
||||
|
||||
wo_a is converted to FP8 for fp8_einsum (no input_scale needed).
|
||||
Compressor weights are reconstructed from checkpoint sub-weights.
|
||||
"""Post-load setup of CuTeDSL NVFP4 runners for attention and shared experts.
|
||||
|
||||
Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM.
|
||||
For attention projections (fused_wqa_wkv, wq_b, wo_b), installs
|
||||
CuTeDSLNvfp4Method which creates CuTeDSL runners during
|
||||
process_weights_after_loading.
|
||||
|
||||
For shared experts, creates CuTeDSLSharedExpertRunner which handles
|
||||
the full L1 (gate_up) + SiLU + L2 (down) pipeline.
|
||||
|
||||
wo_a is converted to FP8 for fp8_einsum (unchanged).
|
||||
Compressor weights are reconstructed from checkpoint sub-weights (unchanged).
|
||||
"""
|
||||
fp8_proj_names = {"wo_a"}
|
||||
from vllm.model_executor.models.cutedsl_quant_method import CuTeDSLNvfp4Method
|
||||
|
||||
fp8_converted = 0
|
||||
compressor_converted = 0
|
||||
input_scale_fixes = 0
|
||||
cutedsl_installed = 0
|
||||
shared_expert_installed = 0
|
||||
|
||||
_shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None
|
||||
|
||||
from tqdm import tqdm
|
||||
for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (fix)NVFP4 attn input_scale", unit="layer"):
|
||||
for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" NVFP4→CuTeDSL setup", unit="layer"):
|
||||
attn = layer.attn
|
||||
|
||||
|
||||
# FP8 conversion: wo_a (used by fp8_einsum, no input_scale)
|
||||
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
for proj_name in fp8_proj_names:
|
||||
if not hasattr(attn, proj_name):
|
||||
continue
|
||||
mod = getattr(attn, proj_name)
|
||||
if not hasattr(mod, "weight"):
|
||||
continue
|
||||
if mod.weight.dtype in (torch.uint8, torch.int8):
|
||||
if hasattr(attn, "wo_a") and hasattr(attn.wo_a, "weight"):
|
||||
if attn.wo_a.weight.dtype in (torch.uint8, torch.int8):
|
||||
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
|
||||
self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX)
|
||||
self._convert_nvfp4_to_fp8(attn.wo_a, E2M1_LUT, FP8_MAX)
|
||||
fp8_converted += 1
|
||||
|
||||
# Fix input_scale for attention NVFP4 projections
|
||||
# process_weights_after_loading reads input_scale and computes
|
||||
# input_global_scale_inv = 1/input_scale. By fixing input_scale
|
||||
# here, the quant method will propagate the correct value.
|
||||
|
||||
# Install CuTeDSL quant method on attention NVFP4 projections.
|
||||
# When vLLM calls process_weights_after_loading, CuTeDSLNvfp4Method
|
||||
# will read the NVFP4 weights, create CuTeDSL runners, and swap
|
||||
# the quant method to CuTeDSLNvfp4LinearMethod.
|
||||
for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]:
|
||||
if not hasattr(attn, proj_name):
|
||||
continue
|
||||
mod = getattr(attn, proj_name)
|
||||
if not hasattr(mod, "input_scale"):
|
||||
continue
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
|
||||
# Temporarily dequantize weight to BF16 for warmup
|
||||
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
|
||||
w_uint8 = mod.weight.data
|
||||
w_bf16_unpacked = self._unpack_nvfp4_to_bf16(w_uint8, E2M1_LUT, w_uint8.device)
|
||||
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
|
||||
block_scale = self._block_scale_to_float32(mod.weight_scale.data)
|
||||
if block_scale.dim() == 2 and w_bf16_unpacked.dim() == 2:
|
||||
block_size = w_bf16_unpacked.shape[1] // block_scale.shape[1]
|
||||
block_scale_expanded = block_scale.unsqueeze(-1).expand(-1, -1, block_size).reshape(w_bf16_unpacked.shape)
|
||||
else:
|
||||
block_scale_expanded = block_scale
|
||||
global_scale = mod.weight_scale_2.data.max().item()
|
||||
w_bf16_dequant = (w_bf16_unpacked.float() * block_scale_expanded * global_scale).to(torch.bfloat16)
|
||||
else:
|
||||
w_bf16_dequant = w_bf16_unpacked
|
||||
|
||||
# Compute correct input_scale from warmup
|
||||
with torch.no_grad():
|
||||
in_features = w_bf16_dequant.shape[-1]
|
||||
dummy_input = torch.randn(256, in_features, dtype=torch.bfloat16, device=mod.weight.device) * 2.0
|
||||
ref_output = torch.nn.functional.linear(dummy_input, w_bf16_dequant)
|
||||
act_amax = ref_output.amax().item()
|
||||
del w_bf16_unpacked, w_bf16_dequant, ref_output
|
||||
|
||||
# input_scale should be 1/(amax * headroom) — this is the
|
||||
# activation global scale that maps activations to FP4 range.
|
||||
# process_weights_after_loading computes:
|
||||
# input_global_scale_inv = input_scale.max()
|
||||
# input_global_scale = 1 / input_global_scale_inv
|
||||
headroom = 1.2
|
||||
new_input_scale = 1.0 / (act_amax * headroom) if act_amax > 0 else mod.input_scale.data
|
||||
|
||||
if layer_idx == 0:
|
||||
old_input_scale = mod.input_scale.data.item() if mod.input_scale.data.numel() == 1 else mod.input_scale.data.max().item()
|
||||
print(f"[CLAWMINE] Layer 0: {proj_name} input_scale: {old_input_scale:.8f} → {new_input_scale:.8f} (act_amax={act_amax:.4f})")
|
||||
|
||||
mod.input_scale = torch.nn.Parameter(
|
||||
torch.tensor([new_input_scale] * mod.input_scale.data.numel(), dtype=mod.input_scale.data.dtype, device=mod.input_scale.data.device),
|
||||
requires_grad=False
|
||||
)
|
||||
input_scale_fixes += 1
|
||||
is_fused = (proj_name == "fused_wqa_wkv")
|
||||
mod.quant_method = CuTeDSLNvfp4Method(is_fused=is_fused)
|
||||
cutedsl_installed += 1
|
||||
|
||||
_shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None
|
||||
|
||||
from tqdm import tqdm
|
||||
for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (upcast)NVFP4→BF16 attn projs", unit="layer"):
|
||||
attn = layer.attn
|
||||
|
||||
|
||||
# Compressor: still needs BF16 reconstruction
|
||||
# Compressor: BF16 reconstruction (unchanged)
|
||||
mla_attn = getattr(attn, "mla_attn", None)
|
||||
if mla_attn is not None:
|
||||
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
|
||||
@@ -1794,49 +1745,147 @@ class DeepseekV4Model(nn.Module):
|
||||
if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"):
|
||||
compressor_converted += self._reconstruct_compressor_weight(
|
||||
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index)
|
||||
|
||||
|
||||
# Shared expert: install CuTeDSL shared expert runner
|
||||
ffn = layer.ffn
|
||||
if hasattr(ffn, 'shared_experts') and ffn.shared_experts is not None:
|
||||
swiglu_limit = ffn.swiglu_limit if hasattr(ffn, 'swiglu_limit') else None
|
||||
se = ffn.shared_experts
|
||||
if self._install_shared_expert_runner(se, swiglu_limit, layer_idx):
|
||||
shared_expert_installed += 1
|
||||
|
||||
|
||||
|
||||
def _dequant_nvfp4_to_bf16(self, mod, e2m1_lut):
|
||||
"""Dequantize NVFP4 weight to bf16 for normal .forward() path."""
|
||||
w_uint8 = mod.weight.data
|
||||
device = w_uint8.device
|
||||
w_bf16 = self._unpack_nvfp4_to_bf16(w_uint8, e2m1_lut, device)
|
||||
|
||||
# Dequantize with scales
|
||||
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
|
||||
block_scale = self._block_scale_to_float32(mod.weight_scale.data)
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
block_scale_expanded = block_scale.unsqueeze(-1).expand(
|
||||
-1, -1, block_size
|
||||
).reshape(w_bf16.shape)
|
||||
else:
|
||||
block_scale_expanded = block_scale
|
||||
global_scale = mod.weight_scale_2.data.max().item()
|
||||
input_scale = (
|
||||
mod.input_scale.data.max().item()
|
||||
if hasattr(mod, "input_scale")
|
||||
else 1.0
|
||||
)
|
||||
# NOTE: input_scale is for ACTIVATIONS, not weights.
|
||||
# Weight dequant = e2m1 * block_scale * global_scale (NO input_scale)
|
||||
w_dequant = w_bf16.float() * block_scale_expanded * global_scale
|
||||
w_dequant = w_dequant.to(torch.bfloat16)
|
||||
def _install_shared_expert_runner(self, se_mlp, swiglu_limit: float | None, layer_idx: int) -> bool:
|
||||
"""Install CuTeDSL shared expert runner on a DeepseekV4MLP.
|
||||
|
||||
Extracts gate_up and down NVFP4 weights, creates
|
||||
CuTeDSLSharedExpertRunner, and replaces the MLP's forward
|
||||
with the fused L1+SiLU+L2 pipeline.
|
||||
"""
|
||||
from cutedsl.shared_expert_pipeline import CuTeDSLSharedExpertRunner
|
||||
|
||||
gate_up = se_mlp.gate_up_proj
|
||||
down = se_mlp.down_proj
|
||||
|
||||
# Check that both projections have NVFP4 weights
|
||||
if not (hasattr(gate_up, "weight") and hasattr(down, "weight")):
|
||||
return False
|
||||
if gate_up.weight.dtype not in (torch.uint8, torch.int8):
|
||||
return False
|
||||
if down.weight.dtype not in (torch.uint8, torch.int8):
|
||||
return False
|
||||
|
||||
device = gate_up.weight.device
|
||||
hidden_size = gate_up.weight.shape[1] * 2 # 2 FP4 per uint8
|
||||
intermediate_size_2x = gate_up.weight.shape[0] # gate + up stacked
|
||||
intermediate_size = intermediate_size_2x // 2
|
||||
|
||||
# ── L1: gate_up (MergedColumnParallelLinear, gate + up fused) ──
|
||||
l1_w_uint8 = gate_up.weight.data # (2*intermediate, hidden//2) uint8
|
||||
l1_sf = gate_up.weight_scale.data # (2*intermediate, hidden//16) float8
|
||||
l1_gs_data = gate_up.weight_scale_2.data # float32 [2] (gate, up)
|
||||
|
||||
# uint8 → float4_e2m1fn_x2, permute to (K_packed, N)
|
||||
l1_w_fp4 = l1_w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
||||
|
||||
# Block scales: (N, K_sf) → (K_sf, N)
|
||||
if l1_sf.dtype != torch.float8_e4m3fn:
|
||||
l1_sf = l1_sf.to(torch.float8_e4m3fn)
|
||||
l1_sf = l1_sf.permute(1, 0).contiguous()
|
||||
|
||||
# Dual global scales: normalize to max, fold ratio into block scales
|
||||
l1_gs1 = l1_gs_data[0].item()
|
||||
l1_gs2 = l1_gs_data[1].item()
|
||||
l1_gs = max(l1_gs1, l1_gs2)
|
||||
if l1_gs1 != l1_gs2:
|
||||
l1_sf_f32 = l1_sf.float()
|
||||
# After permute to (K_sf, N): first intermediate rows are gate, then up
|
||||
l1_sf_f32[:, :intermediate_size] *= (l1_gs1 / l1_gs)
|
||||
l1_sf_f32[:, intermediate_size:] *= (l1_gs2 / l1_gs)
|
||||
l1_sf = l1_sf_f32.to(torch.float8_e4m3fn)
|
||||
|
||||
# ── L2: down (RowParallelLinear, single projection) ──
|
||||
l2_w_uint8 = down.weight.data # (hidden, intermediate//2) uint8
|
||||
l2_sf = down.weight_scale.data # (hidden, intermediate//16) float8
|
||||
l2_gs = down.weight_scale_2.data.max().item() # float32 scalar
|
||||
|
||||
# uint8 → float4_e2m1fn_x2, permute to (K_packed, N)
|
||||
l2_w_fp4 = l2_w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
||||
|
||||
# Block scales: (N, K_sf) → (K_sf, N)
|
||||
if l2_sf.dtype != torch.float8_e4m3fn:
|
||||
l2_sf = l2_sf.to(torch.float8_e4m3fn)
|
||||
l2_sf = l2_sf.permute(1, 0).contiguous()
|
||||
|
||||
# Create runner, set weights, finalize
|
||||
runner = CuTeDSLSharedExpertRunner(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
device=device,
|
||||
swiglu_limit=swiglu_limit if swiglu_limit is not None else 10.0,
|
||||
)
|
||||
runner.l1_fp4 = [l1_w_fp4]
|
||||
runner.l1_sf = [l1_sf]
|
||||
runner.l1_gs = [l1_gs]
|
||||
runner.l2_fp4 = [l2_w_fp4]
|
||||
runner.l2_sf = [l2_sf]
|
||||
runner.l2_gs = [l2_gs]
|
||||
runner.finalize_weights()
|
||||
|
||||
# Warmup: compute activation global scales
|
||||
with torch.no_grad():
|
||||
sample = torch.randn(min(8, 256), hidden_size,
|
||||
dtype=torch.bfloat16, device=device) * 2.0
|
||||
runner.compute_activation_global_scales(sample)
|
||||
|
||||
# Replace the MLP's forward with the runner
|
||||
se_mlp._cutedsl_runner = runner
|
||||
|
||||
# Monkey-patch forward to use the CuTeDSL runner
|
||||
original_cls = type(se_mlp)
|
||||
|
||||
def _cutedsl_forward(self, x):
|
||||
output = self._cutedsl_runner.run(x)
|
||||
# Down_proj with reduce_results may need all-reduce handled
|
||||
# by RowParallelLinear. Since we bypassed it, check if we need
|
||||
# to all-reduce manually.
|
||||
if hasattr(self, '_needs_tp_reduce') and self._needs_tp_reduce:
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
output = tensor_model_parallel_all_reduce(output)
|
||||
return output
|
||||
|
||||
import types
|
||||
se_mlp.forward = types.MethodType(_cutedsl_forward, se_mlp)
|
||||
|
||||
# Check if down_proj needs TP all-reduce
|
||||
# reduce_results=True means the original RowParallelLinear would all-reduce
|
||||
if hasattr(down, 'reduce_results') and down.reduce_results and getattr(down, 'tp_size', 1) > 1:
|
||||
se_mlp._needs_tp_reduce = True
|
||||
else:
|
||||
w_dequant = w_bf16
|
||||
|
||||
# Free source tensors eagerly to avoid holding uint8+bf16+fp32 simultaneously
|
||||
del w_uint8, w_bf16
|
||||
mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False)
|
||||
del w_dequant
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
for attr in ("weight_scale", "weight_scale_2", "input_scale",
|
||||
"weight_scale_inv"):
|
||||
if hasattr(mod, attr):
|
||||
delattr(mod, attr)
|
||||
se_mlp._needs_tp_reduce = False
|
||||
|
||||
# Free NVFP4 params from gate_up and down (replace with dummy BF16)
|
||||
for mod in [gate_up, down]:
|
||||
out_dim = mod.weight.shape[0]
|
||||
in_dim = mod.weight.shape[1] * 2
|
||||
mod.weight = torch.nn.Parameter(
|
||||
torch.zeros(out_dim, in_dim, dtype=torch.bfloat16,
|
||||
device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
for attr in ("weight_scale", "weight_scale_2", "input_scale",
|
||||
"input_global_scale", "input_global_scale_inv",
|
||||
"weight_global_scale", "alpha", "weight_scale_inv"):
|
||||
if hasattr(mod, attr):
|
||||
try:
|
||||
delattr(mod, attr)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
def _convert_nvfp4_to_fp8(self, mod, e2m1_lut, fp8_max):
|
||||
"""Convert NVFP4 weight to FP8 for fp8_einsum path (wo_a only).
|
||||
@@ -2407,59 +2456,6 @@ class DeepseekV4ForCausalLM(nn.Module):
|
||||
del residual, fn, hc_scale, hc_base, x, post_mix, comb_mix
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _post_quant_fix(self) -> None:
|
||||
"""Called by vLLM's process_weights_after_loading AFTER quant methods
|
||||
have set up their attributes. Dequantizes NVFP4 weights to BF16 for
|
||||
attention projections and shared experts because
|
||||
FlashInferCutlassNvFp4LinearKernel uses broken input_global_scale_inv."""
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
|
||||
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
|
||||
fixed = 0
|
||||
for layer_idx, layer in enumerate(self.model.layers):
|
||||
attn = layer.attn
|
||||
# Attention projections
|
||||
for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]:
|
||||
if not hasattr(attn, proj_name):
|
||||
continue
|
||||
mod = getattr(attn, proj_name)
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
self.model._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
for attr in ("weight_scale", "weight_scale_2", "input_scale",
|
||||
"input_global_scale", "input_global_scale_inv",
|
||||
"weight_global_scale", "alpha", "weight_scale_inv"):
|
||||
if hasattr(mod, attr):
|
||||
try: delattr(mod, attr)
|
||||
except: pass
|
||||
fixed += 1
|
||||
|
||||
# Shared expert projections (also NVFP4 with broken input_scale)
|
||||
ffn = layer.ffn
|
||||
if hasattr(ffn, 'shared_experts'):
|
||||
for proj_name in ["gate_up_proj", "down_proj"]:
|
||||
se = ffn.shared_experts
|
||||
if not hasattr(se, proj_name):
|
||||
continue
|
||||
mod = getattr(se, proj_name)
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
self.model._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
for attr in ("weight_scale", "weight_scale_2", "input_scale",
|
||||
"input_global_scale", "input_global_scale_inv",
|
||||
"weight_global_scale", "alpha", "weight_scale_inv"):
|
||||
if hasattr(mod, attr):
|
||||
try: delattr(mod, attr)
|
||||
except: pass
|
||||
fixed += 1
|
||||
print(f" [CLAWMINE] Post-quant fix: {fixed} attention projections → BF16 ✓", flush=True)
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
def _register_post_quant_fix(self) -> None:
|
||||
"""No-op — we use _post_quant_fix() called from process_weights_after_loading."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -121,12 +121,6 @@ def process_weights_after_loading(
|
||||
with device_loading_context(module, target_device):
|
||||
module.process_weights_after_loading(model_config.dtype)
|
||||
|
||||
# Needed for torchao model reloading via model.reload_weights
|
||||
# @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
|
||||
# Custom: allow models to run post-quant-init fixes
|
||||
if hasattr(model, '_post_quant_fix'):
|
||||
model._post_quant_fix()
|
||||
|
||||
if model_config.quantization == "torchao":
|
||||
set_torchao_reload_attrs(model, model_config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user