The CPU dummy weight broke torch.mm(compressor.weight.T) which expects GPU tensors. Instead, reduce max_model_len to fit KV cache within available memory (876544 instead of 1048576).
158 lines
6.5 KiB
Python
158 lines
6.5 KiB
Python
"""CuTeDSL NVFP4 Quantization Method for vLLM

Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM.
After process_weights_after_loading, the module's quant_method is swapped
to CuTeDSLNvfp4LinearMethod which routes forward() through CuTeDSL
via torch.library.custom_op (opaque to torch.compile).
"""
|
|
|
|
import torch
|
|
|
|
from vllm.model_executor.layers.linear import LinearMethodBase
|
|
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
|
|
|
|
|
|
class CuTeDSLNvfp4Method(LinearMethodBase):
    """Pre-processing quant method that sets up CuTeDSL runners.

    Installed on NVFP4 linear layers before process_weights_after_loading.
    When vLLM calls process_weights_after_loading, this method:

    1. Reads the NVFP4 weights (packed uint8), block scales (cast to
       float8_e4m3fn when needed) and global scales (weight_scale_2)
    2. Converts them to the CuTeDSL layout
    3. Creates a CuTeDSLNvfp4Linear runner
    4. Registers the runner and stores its registry id on the module
    5. Replaces the weight with a BF16 placeholder and frees the original
       scale params
    6. Replaces quant_method with CuTeDSLNvfp4LinearMethod
    """

    # Attributes removed from the layer once the runner holds its own
    # converted copies of the weights and scales.
    _DISCARDED_ATTRS = (
        "weight_scale", "weight_scale_2", "input_scale",
        "input_global_scale", "input_global_scale_inv",
        "weight_global_scale", "alpha", "weight_scale_inv",
    )

    def __init__(self, is_fused: bool = False):
        """
        Args:
            is_fused: True for MergedColumnParallelLinear with two sub-projections
                (e.g., fused_wqa_wkv with q_a + kv, or gate_up_proj with gate + up).
                Handles dual weight_scale_2 the same way as MoE L1:
                normalize to max(gs1, gs2), fold ratio into block scales.
        """
        self.is_fused = is_fused

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """Intentionally a no-op.

        We don't create weights — ModelOptNvFp4LinearMethod already did that.
        This method is only installed after weight loading.
        """
        pass

    @staticmethod
    def _merge_dual_global_scales(layer, sf, weight_scale_2, out_features):
        """Collapse two per-projection global scales into a single one.

        The shared global scale is max(gs1, gs2); the ratio of each
        projection's own scale to that maximum is folded into its block
        scales so the product (block scale * global scale) is preserved up
        to float8 rounding.

        Args:
            layer: Module being converted; its optional ``logical_widths``
                attribute gives the column split between the projections.
            sf: Block scales already permuted to (K_sf, N), float8_e4m3fn.
            weight_scale_2: Tensor holding exactly two global scales.
            out_features: Total number of output columns (N).

        Returns:
            Tuple ``(sf, gs)`` of the (possibly rescaled) block scales and
            the shared global scale as a Python float.
        """
        gs1 = weight_scale_2[0].item()
        gs2 = weight_scale_2[1].item()
        gs = max(gs1, gs2)
        if gs1 == gs2:
            # Both projections already share one scale; nothing to fold.
            return sf, gs

        # After permute to (K_sf, N): first sub-projection's output
        # columns, then second sub-projection's output columns.
        logical_widths = getattr(layer, 'logical_widths', None)
        if logical_widths is not None and len(logical_widths) == 2:
            split_point = logical_widths[0]
        else:
            # NOTE(review): this fallback assumes two equally wide
            # projections, which is only right when logical_widths is
            # missing AND the halves really are the same size.
            split_point = out_features // 2

        # Fold ratio into block scales via float32 round-trip. Both ratios
        # are <= 1 because gs is the maximum of the two scales.
        sf_f32 = sf.float()
        sf_f32[:, :split_point] *= (gs1 / gs)
        sf_f32[:, split_point:] *= (gs2 / gs)
        return sf_f32.to(torch.float8_e4m3fn), gs

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the layer's NVFP4 params into a CuTeDSL runner.

        Reads ``layer.weight``, ``layer.weight_scale`` and
        ``layer.weight_scale_2``, builds and registers a CuTeDSLNvfp4Linear
        runner, stores ``_cutedsl_runner_id`` / ``_cutedsl_out_features`` on
        the layer, replaces ``layer.weight`` with a zero BF16 placeholder,
        removes the original scale attributes and finally swaps
        ``layer.quant_method`` to CuTeDSLNvfp4LinearMethod.

        Args:
            layer: The linear module whose weights have just been loaded.
        """
        from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

        w_uint8 = layer.weight.data  # (out, in//2) uint8 packed E2M1
        device = w_uint8.device
        out_features = w_uint8.shape[0]
        in_features = w_uint8.shape[1] * 2  # 2 FP4 values per uint8

        # Convert uint8 → float4_e2m1fn_x2, then permute to (K_packed, N)
        w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

        # Block scales: (N, K_sf) → (K_sf, N)
        sf = layer.weight_scale.data
        if sf.dtype != torch.float8_e4m3fn:
            sf = sf.to(torch.float8_e4m3fn)
        sf = sf.permute(1, 0).contiguous()

        # Global scale
        weight_scale_2 = layer.weight_scale_2.data
        if self.is_fused and weight_scale_2.numel() == 2:
            # Dual global scales (fused_wqa_wkv: q_a + kv, gate_up: gate + up)
            sf, gs = self._merge_dual_global_scales(
                layer, sf, weight_scale_2, out_features)
        else:
            gs = weight_scale_2.max().item()

        # Create CuTeDSL runner
        runner = CuTeDSLNvfp4Linear(
            in_features=in_features,
            out_features=out_features,
            device=device,
        )
        runner.fp4 = [w_fp4]
        runner.sf = [sf]
        runner.gs = [gs]
        runner.finalize_weights()

        # Register runner in global registry (for torch.library.custom_op)
        layer._cutedsl_runner_id = register_runner(runner)
        layer._cutedsl_out_features = out_features

        # Warmup: compute activation global scale from sample data.
        # The checkpoint's input_scale is a calibration-time value that does NOT
        # match what quantize_activation_nvfp4 expects at runtime. Using it
        # produces garbage output (empty EOS tokens), so the scale is derived
        # from a warmup sample instead.
        # Use only 1 token to minimize GPU memory overhead during weight loading.
        # NOTE(review): the sample is synthetic N(0, 1) noise scaled by 2.0,
        # not real activations, and it draws from the global RNG — confirm
        # this is representative enough for the target model.
        with torch.no_grad():
            sample = torch.randn(1, in_features,
                                 dtype=torch.bfloat16, device=device) * 2.0
            runner.compute_activation_global_scale(sample)
            del sample
        torch.cuda.empty_cache()

        # Replace weight with a dummy BF16 tensor on the same device as the
        # original weight: vLLM module introspection needs a weight, and some
        # vLLM code paths like torch.mm(compressor.weight.T) expect it on GPU.
        layer.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=torch.bfloat16,
                        device=device),
            requires_grad=False,
        )

        # Free original NVFP4 params
        for attr in self._DISCARDED_ATTRS:
            if not hasattr(layer, attr):
                continue
            try:
                delattr(layer, attr)
            except AttributeError:
                # Visible through hasattr but not deletable from the
                # instance (e.g. a class-level attribute); nothing to free.
                pass

        # Swap quant method to the forward-only one
        layer.quant_method = CuTeDSLNvfp4LinearMethod()

    def apply(self, layer, x, bias=None):
        """Always raises: the forward pass belongs to the replacement method.

        Raises:
            NotImplementedError: If called before process_weights_after_loading
                has swapped in CuTeDSLNvfp4LinearMethod.
        """
        raise NotImplementedError(
            "CuTeDSLNvfp4Method should be replaced by "
            "CuTeDSLNvfp4LinearMethod after process_weights_after_loading"
        )
|
|
|
|
|
|
class CuTeDSLNvfp4LinearMethod(LinearMethodBase):
    """Forward path: BF16 input → CuTeDSL NVFP4 GEMM → BF16 output."""

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """No-op: this method allocates no parameters of its own."""
        return None

    def apply(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor:
        """Run the GEMM for ``layer`` and add the bias when one is given.

        Args:
            layer: Module carrying ``_cutedsl_runner_id`` and
                ``_cutedsl_out_features``.
            x: Input activations passed straight to ``nvfp4_linear_gemm``.
            bias: Optional bias added to the GEMM result.

        Returns:
            The GEMM output, plus ``bias`` if it is not None.
        """
        runner_id = layer._cutedsl_runner_id
        n_out = layer._cutedsl_out_features
        gemm_out = nvfp4_linear_gemm(x, runner_id, n_out)
        return gemm_out if bias is None else gemm_out + bias
|