From 67fe677c53e7a5f5d27a4f07edd66f44bea0495c Mon Sep 17 00:00:00 2001 From: danielafrimi <45691845+danielafrimi@users.noreply.github.com> Date: Mon, 26 Jan 2026 20:04:20 +0200 Subject: [PATCH] [FIX] Always support TP > 4 for FP4 Gemm (#31099) Signed-off-by: dafrimi Co-authored-by: root --- .../schemes/compressed_tensors_w4a4_nvfp4.py | 30 +++++++-- .../layers/quantization/modelopt.py | 23 ++++++- .../layers/quantization/utils/quant_utils.py | 67 +++++++++++++++++++ 3 files changed, 114 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 762498378..7ed769bad 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -16,6 +16,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( cutlass_fp4_supported, + pad_nvfp4_activation_for_cutlass, + pad_nvfp4_weight_for_cutlass, + slice_nvfp4_output, swizzle_blockscale, ) from vllm.model_executor.parameter import ( @@ -159,9 +162,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): if self.backend == "fbgemm": swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) - layer.weight_packed = Parameter( - layer.weight_packed.data, requires_grad=False - ) + + # Pad weights for CUTLASS/FlashInfer kernel alignment (K and N + # divisible by 32). fbgemm has its own layout requirements. + if self.backend in ("cutlass", "flashinfer-cutlass"): + weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass( + layer.weight_packed.data + ) + layer.weights_padding_cols = weights_padding_cols + layer.weight_packed = Parameter(weight, requires_grad=False) + else: + layer.weights_padding_cols = 0 + layer.weight_packed = Parameter( + layer.weight_packed.data, requires_grad=False + ) layer.alpha = Parameter( 1 / (layer.input_global_scale * layer.weight_global_scale), @@ -187,7 +201,8 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): return out output_dtype = x.dtype - output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]] + output_size = layer.output_size_per_partition + output_shape = [*x.shape[:-1], output_size] # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant( @@ -197,6 +212,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): backend=self.backend, ) + # Pad activations to match weight K-dimension padding + weights_padding_cols = getattr(layer, "weights_padding_cols", 0) + x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols) + mm_args = ( x_fp4, layer.weight_packed, @@ -221,6 +240,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): assert self.backend == "cutlass" out = cutlass_scaled_fp4_mm(*mm_args) + # Slice output to remove N-dimension padding + out = slice_nvfp4_output(out, output_size) + if bias is not None: out = out + bias return out.view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index f26aa045b..476ad618e 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -84,6 +84,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTokenSym, kNvfp4Dynamic, kNvfp4Static, + pad_nvfp4_activation_for_cutlass, + pad_nvfp4_weight_for_cutlass, + slice_nvfp4_output, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -1280,9 +1283,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False) else: + # Swizzle block scales and pad the packed NVFP4 weights for kernel + # alignment (CUTLASS/FlashInfer require K and N divisible by 32). swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) - layer.weight = Parameter(layer.weight.data, requires_grad=False) + + weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass( + layer.weight.data + ) + layer.weights_padding_cols = weights_padding_cols + layer.weight = Parameter(weight, requires_grad=False) def apply( self, @@ -1304,7 +1314,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ) output_dtype = x.dtype - output_shape = [x.shape[0], layer.weight.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant( @@ -1319,6 +1328,12 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): assert layer.weight_scale.dtype == torch.float8_e4m3fn assert layer.alpha.dtype == torch.float32 + # Pad activations to match weight K-dimension padding + weights_padding_cols = getattr(layer, "weights_padding_cols", 0) + output_size = layer.output_size_per_partition + output_shape = [x.shape[0], output_size] + x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols) + mm_args = ( x_fp4, layer.weight, @@ -1327,6 +1342,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): layer.alpha, output_dtype, ) + if self.backend.startswith("flashinfer-"): backend_name = self.backend[len("flashinfer-") :] out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) @@ -1334,6 +1350,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): assert self.backend == "cutlass" out = cutlass_scaled_fp4_mm(*mm_args) + # Slice output to remove N-dimension padding + out = slice_nvfp4_output(out, output_size) + if bias is not None: out = out + bias return out.view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 727315106..5dbd05f16 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -868,3 +868,70 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens t |= ((nib - 8) & 0xF) << shift return t + + +def round_up(x: int, m: int) -> int: + """Round up x to the nearest multiple of m.""" + return (x + m - 1) // m * m + + +def pad_nvfp4_weight_for_cutlass( + weight: torch.Tensor, + alignment: int = 32, +) -> tuple[torch.Tensor, int]: + """ + Pad packed NVFP4 weights so that both N (rows) and K (columns) satisfy + the alignment constraints required by CUTLASS / FlashInfer FP4 kernels. + + CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible + by 32 for aligned memory access and efficient tensor core operations. + """ + weight_current_rows = weight.shape[0] + + # Pad N dimension (rows) if not aligned + if weight_current_rows % alignment != 0: + total_rows = round_up(weight_current_rows, alignment) + pad_rows = total_rows - weight_current_rows + weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_rows)).contiguous() + + # Check K dimension alignment + # 2 FP4 items are packed per byte in the input dimension + weight_current_col_bytes = weight.shape[1] + weight_current_col_elements = weight_current_col_bytes * 2 + + weights_padding_bytes = 0 + if weight_current_col_elements % alignment != 0: + total_cols = round_up(weight_current_col_elements, alignment) + pad_cols = total_cols - weight_current_col_elements + # Convert from FP4 element count to bytes (2 FP4 values per byte) + # pad_cols is always even since alignment=32 and current elements are even + pad_bytes = pad_cols // 2 + weight = torch.nn.functional.pad(weight, (0, pad_bytes, 0, 0)).contiguous() + weights_padding_bytes = pad_bytes + + return weight, weights_padding_bytes + + +def pad_nvfp4_activation_for_cutlass( + x_fp4: torch.Tensor, + weights_padding_bytes: int, +) -> torch.Tensor: + """ + Pad packed FP4 activations to match the K-dimension padding applied to weights. + The padding is in bytes (tensor dimension), not FP4 elements. + """ + if weights_padding_bytes > 0: + return torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes)).contiguous() + return x_fp4 + + +def slice_nvfp4_output( + out: torch.Tensor, + output_size: int, +) -> torch.Tensor: + """ + Slice the output tensor to remove padding in N dimension if weight was padded. + """ + if out.shape[-1] != output_size: + return out[..., :output_size].contiguous() + return out