# Files
# nvfp4-megamoe-kernel/vllm/cutedsl_quant_method.py
#
# 141 lines
# 5.5 KiB
# Python

"""CuTeDSL NVFP4 Quantization Method for vLLM
Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM.
After process_weights_after_loading, the module's quant_method is swapped
to CuTeDSLNvfp4LinearMethod which routes forward() through CuTeDSL.
"""
import torch
from vllm.model_executor.layers.linear import LinearMethodBase
class CuTeDSLNvfp4Method(LinearMethodBase):
    """Pre-processing quant method that sets up CuTeDSL runners.

    Installed on NVFP4 linear layers before process_weights_after_loading.
    When vLLM calls process_weights_after_loading, this method:

    1. Reads NVFP4 weights (uint8, float8 block scales, float32 global scales)
    2. Converts to CuTeDSL format
    3. Creates CuTeDSLNvfp4Linear runner
    4. Stores runner on the module
    5. Frees original weight/scale params
    6. Replaces quant_method with CuTeDSLNvfp4LinearMethod
    """

    # Quantization attributes removed from the layer once the runner holds
    # the converted weights.
    _STALE_ATTRS = (
        "weight_scale", "weight_scale_2", "input_scale",
        "input_global_scale", "input_global_scale_inv",
        "weight_global_scale", "alpha", "weight_scale_inv",
    )

    # Number of rows in the synthetic batch used to seed the activation
    # global scale.
    _WARMUP_ROWS = 8

    def __init__(self, is_fused: bool = False):
        """
        Args:
            is_fused: True for MergedColumnParallelLinear with two
                sub-projections (e.g., fused_wqa_wkv with q_a + kv, or
                gate_up_proj with gate + up). Handles dual weight_scale_2
                the same way as MoE L1: normalize to max(gs1, gs2), fold
                ratio into block scales.
        """
        super().__init__()
        self.is_fused = is_fused

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """No-op; present to satisfy the quant-method interface."""
        # We don't create weights — ModelOptNvFp4LinearMethod already did that.
        # This method is only installed after weight loading.
        pass

    def _fold_global_scales(self, layer, sf, out_features):
        """Collapse ``layer.weight_scale_2`` into a single global scale.

        Args:
            layer: Module carrying ``weight_scale_2`` (and, optionally,
                ``logical_widths``).
            sf: Block scales already permuted to (K_sf, N), float8_e4m3fn.
            out_features: N, used for the fallback split point.

        Returns:
            ``(block_scales, global_scale)``. ``block_scales`` is ``sf``
            itself unless two differing global scales had to be folded in,
            in which case it is a re-quantized float8_e4m3fn copy.
        """
        weight_scale_2 = layer.weight_scale_2.data
        if not (self.is_fused and weight_scale_2.numel() == 2):
            # Single scale, or a layout this method does not fold: the
            # largest entry is used as-is.
            return sf, weight_scale_2.max().item()

        # Dual global scales (fused_wqa_wkv: q_a + kv, gate_up: gate + up)
        gs1 = weight_scale_2[0].item()
        gs2 = weight_scale_2[1].item()
        gs = max(gs1, gs2)
        if gs1 == gs2:
            return sf, gs

        # Fold ratio into block scales via float32 round-trip
        sf_f32 = sf.float()
        # After permute to (K_sf, N): first sub-projection's output
        # columns, then second sub-projection's output columns
        logical_widths = getattr(layer, 'logical_widths', None)
        if logical_widths is not None and len(logical_widths) == 2:
            split_point = logical_widths[0]
        else:
            # NOTE(review): assumes both sub-projections are the same
            # width; confirm for layers that lack logical_widths.
            split_point = out_features // 2
        sf_f32[:, :split_point] *= (gs1 / gs)
        sf_f32[:, split_point:] *= (gs2 / gs)
        return sf_f32.to(torch.float8_e4m3fn), gs

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the layer's NVFP4 tensors and hand them to a CuTeDSL runner.

        Side effects on ``layer``: sets ``_cutedsl_runner``, replaces
        ``weight`` with a BF16 placeholder, removes the quantization
        attributes listed in ``_STALE_ATTRS`` and swaps ``quant_method`` to
        ``CuTeDSLNvfp4LinearMethod``.
        """
        from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

        w_uint8 = layer.weight.data  # (out, in//2) uint8 packed E2M1
        device = w_uint8.device
        out_features = w_uint8.shape[0]
        in_features = w_uint8.shape[1] * 2  # 2 FP4 values per uint8

        # Convert uint8 → float4_e2m1fn_x2, then permute to (K_packed, N)
        w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

        # Block scales: (N, K_sf) → (K_sf, N)
        sf = layer.weight_scale.data
        if sf.dtype != torch.float8_e4m3fn:
            sf = sf.to(torch.float8_e4m3fn)
        sf = sf.permute(1, 0).contiguous()

        # Global scale (may rewrite the block scales for fused layers)
        sf, gs = self._fold_global_scales(layer, sf, out_features)

        # Create CuTeDSL runner
        runner = CuTeDSLNvfp4Linear(
            in_features=in_features,
            out_features=out_features,
            device=device,
        )
        runner.fp4 = [w_fp4]
        runner.sf = [sf]
        runner.gs = [gs]
        runner.finalize_weights()

        # Store runner on the module
        layer._cutedsl_runner = runner

        # Warmup: compute activation global scale from sample data
        with torch.no_grad():
            sample = torch.randn(self._WARMUP_ROWS, in_features,
                                 dtype=torch.bfloat16, device=device) * 2.0
            runner.compute_activation_global_scale(sample)

        # Replace weight with dummy BF16 (needed by vLLM module introspection)
        # NOTE(review): this allocates a full (out, in) BF16 tensor per
        # layer; confirm whether introspection needs real storage or only
        # shape/dtype/device.
        layer.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=torch.bfloat16,
                        device=device),
            requires_grad=False,
        )

        # Free original NVFP4 params
        for attr in self._STALE_ATTRS:
            if not hasattr(layer, attr):
                continue
            try:
                delattr(layer, attr)
            except AttributeError:
                # Best-effort cleanup: the name resolves through the class
                # (not the instance), so there is nothing to free here.
                continue

        # Swap quant method to the forward-only one
        layer.quant_method = CuTeDSLNvfp4LinearMethod()

    def apply(self, layer, x, bias=None):
        """Always raises; forward passes go through the swapped-in method.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError(
            "CuTeDSLNvfp4Method should be replaced by "
            "CuTeDSLNvfp4LinearMethod after process_weights_after_loading"
        )
class CuTeDSLNvfp4LinearMethod(LinearMethodBase):
    """Forward path: BF16 input → CuTeDSL NVFP4 GEMM → BF16 output."""

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """No-op; ``apply`` only needs ``layer._cutedsl_runner`` to exist."""
        pass

    def apply(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor:
        """Run ``x`` through the layer's CuTeDSL runner.

        Args:
            layer: Module that has a callable ``_cutedsl_runner`` attribute.
            x: Input activations.
            bias: Optional bias added to the runner output. The runner is
                called with ``x`` only, so the bias is applied here.

        Returns:
            The runner output, plus ``bias`` when one is given.
        """
        out = layer._cutedsl_runner(x)
        if bias is not None:
            out = out + bias
        return out