[FIX] Always support TP > 4 for FP4 Gemm (#31099)

Signed-off-by: dafrimi <dafrimi@nvidia.com>
Co-authored-by: root <root@gpu-51.slurm-workers-slurm.slurm.svc.cluster.local>
This commit is contained in:
danielafrimi
2026-01-26 20:04:20 +02:00
committed by GitHub
parent d56afd45fd
commit 67fe677c53
3 changed files with 114 additions and 6 deletions

View File

@@ -16,6 +16,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported, cutlass_fp4_supported,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
@@ -159,9 +162,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
if self.backend == "fbgemm": if self.backend == "fbgemm":
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False # Pad weights for CUTLASS/FlashInfer kernel alignment (K and N
) # divisible by 32). fbgemm has its own layout requirements.
if self.backend in ("cutlass", "flashinfer-cutlass"):
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight_packed.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
layer.weights_padding_cols = 0
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False
)
layer.alpha = Parameter( layer.alpha = Parameter(
1 / (layer.input_global_scale * layer.weight_global_scale), 1 / (layer.input_global_scale * layer.weight_global_scale),
@@ -187,7 +201,8 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
return out return out
output_dtype = x.dtype output_dtype = x.dtype
output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]] output_size = layer.output_size_per_partition
output_shape = [*x.shape[:-1], output_size]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant( x_fp4, x_blockscale = scaled_fp4_quant(
@@ -197,6 +212,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
backend=self.backend, backend=self.backend,
) )
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = ( mm_args = (
x_fp4, x_fp4,
layer.weight_packed, layer.weight_packed,
@@ -221,6 +240,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
assert self.backend == "cutlass" assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args) out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.view(*output_shape) return out.view(*output_shape)

View File

@@ -84,6 +84,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTokenSym, kFp8StaticTokenSym,
kNvfp4Dynamic, kNvfp4Dynamic,
kNvfp4Static, kNvfp4Static,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -1280,9 +1283,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight = Parameter(weight, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False)
else: else:
# Swizzle block scales and pad the packed NVFP4 weights for kernel
# alignment (CUTLASS/FlashInfer require K and N divisible by 32).
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight = Parameter(weight, requires_grad=False)
def apply( def apply(
self, self,
@@ -1304,7 +1314,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
) )
output_dtype = x.dtype output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant( x_fp4, x_blockscale = scaled_fp4_quant(
@@ -1319,6 +1328,12 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert layer.weight_scale.dtype == torch.float8_e4m3fn assert layer.weight_scale.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32 assert layer.alpha.dtype == torch.float32
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
output_size = layer.output_size_per_partition
output_shape = [x.shape[0], output_size]
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = ( mm_args = (
x_fp4, x_fp4,
layer.weight, layer.weight,
@@ -1327,6 +1342,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.alpha, layer.alpha,
output_dtype, output_dtype,
) )
if self.backend.startswith("flashinfer-"): if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :] backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
@@ -1334,6 +1350,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert self.backend == "cutlass" assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args) out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.view(*output_shape) return out.view(*output_shape)

View File

@@ -868,3 +868,70 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens
t |= ((nib - 8) & 0xF) << shift t |= ((nib - 8) & 0xF) << shift
return t return t
def round_up(x: int, m: int) -> int:
"""Round up x to the nearest multiple of m."""
return (x + m - 1) // m * m
def pad_nvfp4_weight_for_cutlass(
weight: torch.Tensor,
alignment: int = 32,
) -> tuple[torch.Tensor, int]:
"""
Pad packed NVFP4 weights so that both N (rows) and K (columns) satisfy
the alignment constraints required by CUTLASS / FlashInfer FP4 kernels.
CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible
by 32 for aligned memory access and efficient tensor core operations.
"""
weight_current_rows = weight.shape[0]
# Pad N dimension (rows) if not aligned
if weight_current_rows % alignment != 0:
total_rows = round_up(weight_current_rows, alignment)
pad_rows = total_rows - weight_current_rows
weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_rows)).contiguous()
# Check K dimension alignment
# 2 FP4 items are packed per byte in the input dimension
weight_current_col_bytes = weight.shape[1]
weight_current_col_elements = weight_current_col_bytes * 2
weights_padding_bytes = 0
if weight_current_col_elements % alignment != 0:
total_cols = round_up(weight_current_col_elements, alignment)
pad_cols = total_cols - weight_current_col_elements
# Convert from FP4 element count to bytes (2 FP4 values per byte)
# pad_cols is always even since alignment=32 and current elements are even
pad_bytes = pad_cols // 2
weight = torch.nn.functional.pad(weight, (0, pad_bytes, 0, 0)).contiguous()
weights_padding_bytes = pad_bytes
return weight, weights_padding_bytes
def pad_nvfp4_activation_for_cutlass(
x_fp4: torch.Tensor,
weights_padding_bytes: int,
) -> torch.Tensor:
"""
Pad packed FP4 activations to match the K-dimension padding applied to weights.
The padding is in bytes (tensor dimension), not FP4 elements.
"""
if weights_padding_bytes > 0:
return torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes)).contiguous()
return x_fp4
def slice_nvfp4_output(
out: torch.Tensor,
output_size: int,
) -> torch.Tensor:
"""
Slice the output tensor to remove padding in N dimension if weight was padded.
"""
if out.shape[-1] != output_size:
return out[..., :output_size].contiguous()
return out