"""CuTeDSL NVFP4 Quantization Method for vLLM Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM. After process_weights_after_loading, the module's quant_method is swapped to CuTeDSLNvfp4LinearMethod which routes forward() through CuTeDSL via torch.library.custom_op (opaque to torch.compile). """ import torch from vllm.model_executor.layers.linear import LinearMethodBase from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm class CuTeDSLNvfp4Method(LinearMethodBase): """Pre-processing quant method that sets up CuTeDSL runners. Installed on NVFP4 linear layers before process_weights_after_loading. When vLLM calls process_weights_after_loading, this method: 1. Reads NVFP4 weights (uint8, float8 block scales, float32 global scales) 2. Converts to CuTeDSL format 3. Creates CuTeDSLNvfp4Linear runner 4. Stores runner on the module 5. Frees original weight/scale params 6. Replaces quant_method with CuTeDSLNvfp4LinearMethod """ def __init__(self, is_fused: bool = False): """ Args: is_fused: True for MergedColumnParallelLinear with two sub-projections (e.g., fused_wqa_wkv with q_a + kv, or gate_up_proj with gate + up). Handles dual weight_scale_2 the same way as MoE L1: normalize to max(gs1, gs2), fold ratio into block scales. """ self.is_fused = is_fused def create_weights(self, layer, input_size_per_partition, output_partition_sizes, input_size, output_size, params_dtype, **extra_weight_attrs): # We don't create weights — ModelOptNvFp4LinearMethod already did that. # This method is only installed after weight loading. pass def process_weights_after_loading(self, layer: torch.nn.Module) -> None: from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear w_uint8 = layer.weight.data # (out, in//2) uint8 packed E2M1 device = w_uint8.device out_features = w_uint8.shape[0] in_features = w_uint8.shape[1] * 2 # 2 FP4 values per uint8 # Convert uint8 → float4_e2m1fn_x2, then permute to (K_packed, N) w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() # Block scales: (N, K_sf) → (K_sf, N) sf = layer.weight_scale.data if sf.dtype != torch.float8_e4m3fn: sf = sf.to(torch.float8_e4m3fn) sf = sf.permute(1, 0).contiguous() # Global scale weight_scale_2 = layer.weight_scale_2.data if self.is_fused and weight_scale_2.numel() == 2: # Dual global scales (fused_wqa_wkv: q_a + kv, gate_up: gate + up) gs1 = weight_scale_2[0].item() gs2 = weight_scale_2[1].item() gs = max(gs1, gs2) # Fold ratio into block scales via float32 round-trip if gs1 != gs2: sf_f32 = sf.float() # After permute to (K_sf, N): first sub-projection's output # columns, then second sub-projection's output columns logical_widths = getattr(layer, 'logical_widths', None) if logical_widths is not None and len(logical_widths) == 2: split_point = logical_widths[0] else: split_point = out_features // 2 sf_f32[:, :split_point] *= (gs1 / gs) sf_f32[:, split_point:] *= (gs2 / gs) sf = sf_f32.to(torch.float8_e4m3fn) else: gs = weight_scale_2.max().item() # Create CuTeDSL runner runner = CuTeDSLNvfp4Linear( in_features=in_features, out_features=out_features, device=device, ) runner.fp4 = [w_fp4] runner.sf = [sf] runner.gs = [gs] runner.finalize_weights() # Register runner in global registry (for torch.library.custom_op) layer._cutedsl_runner_id = register_runner(runner) layer._cutedsl_out_features = out_features # Warmup: compute activation global scale from sample data with torch.no_grad(): sample = torch.randn(min(8, 256), in_features, dtype=torch.bfloat16, device=device) * 2.0 runner.compute_activation_global_scale(sample) # Replace weight with dummy BF16 (needed by vLLM module introspection) layer.weight = torch.nn.Parameter( torch.zeros(out_features, in_features, dtype=torch.bfloat16, device=device), requires_grad=False, ) # Free original NVFP4 params for attr in ("weight_scale", "weight_scale_2", "input_scale", "input_global_scale", "input_global_scale_inv", "weight_global_scale", "alpha", "weight_scale_inv"): if hasattr(layer, attr): try: delattr(layer, attr) except Exception: pass # Swap quant method to the forward-only one layer.quant_method = CuTeDSLNvfp4LinearMethod() def apply(self, layer, x, bias=None): raise NotImplementedError( "CuTeDSLNvfp4Method should be replaced by " "CuTeDSLNvfp4LinearMethod after process_weights_after_loading" ) class CuTeDSLNvfp4LinearMethod(LinearMethodBase): """Forward path: BF16 input → CuTeDSL NVFP4 GEMM → BF16 output.""" def create_weights(self, layer, input_size_per_partition, output_partition_sizes, input_size, output_size, params_dtype, **extra_weight_attrs): pass def apply(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor: result = nvfp4_linear_gemm( x, layer._cutedsl_runner_id, layer._cutedsl_out_features, ) if bias is not None: result = result + bias return result