[FIX] Always support TP > 4 for FP4 Gemm (#31099)
Signed-off-by: dafrimi <dafrimi@nvidia.com>
Co-authored-by: root <root@gpu-51.slurm-workers-slurm.slurm.svc.cluster.local>
This commit is contained in:
@@ -16,6 +16,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
cutlass_fp4_supported,
|
cutlass_fp4_supported,
|
||||||
|
pad_nvfp4_activation_for_cutlass,
|
||||||
|
pad_nvfp4_weight_for_cutlass,
|
||||||
|
slice_nvfp4_output,
|
||||||
swizzle_blockscale,
|
swizzle_blockscale,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import (
|
from vllm.model_executor.parameter import (
|
||||||
@@ -159,9 +162,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
if self.backend == "fbgemm":
|
if self.backend == "fbgemm":
|
||||||
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
|
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
|
||||||
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
|
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
|
||||||
layer.weight_packed = Parameter(
|
|
||||||
layer.weight_packed.data, requires_grad=False
|
# Pad weights for CUTLASS/FlashInfer kernel alignment (K and N
|
||||||
)
|
# divisible by 32). fbgemm has its own layout requirements.
|
||||||
|
if self.backend in ("cutlass", "flashinfer-cutlass"):
|
||||||
|
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
|
||||||
|
layer.weight_packed.data
|
||||||
|
)
|
||||||
|
layer.weights_padding_cols = weights_padding_cols
|
||||||
|
layer.weight_packed = Parameter(weight, requires_grad=False)
|
||||||
|
else:
|
||||||
|
layer.weights_padding_cols = 0
|
||||||
|
layer.weight_packed = Parameter(
|
||||||
|
layer.weight_packed.data, requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
layer.alpha = Parameter(
|
layer.alpha = Parameter(
|
||||||
1 / (layer.input_global_scale * layer.weight_global_scale),
|
1 / (layer.input_global_scale * layer.weight_global_scale),
|
||||||
@@ -187,7 +201,8 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
output_dtype = x.dtype
|
output_dtype = x.dtype
|
||||||
output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]]
|
output_size = layer.output_size_per_partition
|
||||||
|
output_shape = [*x.shape[:-1], output_size]
|
||||||
|
|
||||||
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||||
x_fp4, x_blockscale = scaled_fp4_quant(
|
x_fp4, x_blockscale = scaled_fp4_quant(
|
||||||
@@ -197,6 +212,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
backend=self.backend,
|
backend=self.backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Pad activations to match weight K-dimension padding
|
||||||
|
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
|
||||||
|
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
|
||||||
|
|
||||||
mm_args = (
|
mm_args = (
|
||||||
x_fp4,
|
x_fp4,
|
||||||
layer.weight_packed,
|
layer.weight_packed,
|
||||||
@@ -221,6 +240,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
assert self.backend == "cutlass"
|
assert self.backend == "cutlass"
|
||||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||||
|
|
||||||
|
# Slice output to remove N-dimension padding
|
||||||
|
out = slice_nvfp4_output(out, output_size)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
out = out + bias
|
out = out + bias
|
||||||
return out.view(*output_shape)
|
return out.view(*output_shape)
|
||||||
|
|||||||
@@ -84,6 +84,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
kFp8StaticTokenSym,
|
kFp8StaticTokenSym,
|
||||||
kNvfp4Dynamic,
|
kNvfp4Dynamic,
|
||||||
kNvfp4Static,
|
kNvfp4Static,
|
||||||
|
pad_nvfp4_activation_for_cutlass,
|
||||||
|
pad_nvfp4_weight_for_cutlass,
|
||||||
|
slice_nvfp4_output,
|
||||||
swizzle_blockscale,
|
swizzle_blockscale,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
@@ -1280,9 +1283,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||||
layer.weight = Parameter(weight, requires_grad=False)
|
layer.weight = Parameter(weight, requires_grad=False)
|
||||||
else:
|
else:
|
||||||
|
# Swizzle block scales and pad the packed NVFP4 weights for kernel
|
||||||
|
# alignment (CUTLASS/FlashInfer require K and N divisible by 32).
|
||||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||||
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
|
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
|
||||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
|
||||||
|
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
|
||||||
|
layer.weight.data
|
||||||
|
)
|
||||||
|
layer.weights_padding_cols = weights_padding_cols
|
||||||
|
layer.weight = Parameter(weight, requires_grad=False)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@@ -1304,7 +1314,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
output_dtype = x.dtype
|
output_dtype = x.dtype
|
||||||
output_shape = [x.shape[0], layer.weight.shape[0]]
|
|
||||||
|
|
||||||
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||||
x_fp4, x_blockscale = scaled_fp4_quant(
|
x_fp4, x_blockscale = scaled_fp4_quant(
|
||||||
@@ -1319,6 +1328,12 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
assert layer.weight_scale.dtype == torch.float8_e4m3fn
|
assert layer.weight_scale.dtype == torch.float8_e4m3fn
|
||||||
assert layer.alpha.dtype == torch.float32
|
assert layer.alpha.dtype == torch.float32
|
||||||
|
|
||||||
|
# Pad activations to match weight K-dimension padding
|
||||||
|
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
|
||||||
|
output_size = layer.output_size_per_partition
|
||||||
|
output_shape = [x.shape[0], output_size]
|
||||||
|
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
|
||||||
|
|
||||||
mm_args = (
|
mm_args = (
|
||||||
x_fp4,
|
x_fp4,
|
||||||
layer.weight,
|
layer.weight,
|
||||||
@@ -1327,6 +1342,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
layer.alpha,
|
layer.alpha,
|
||||||
output_dtype,
|
output_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.backend.startswith("flashinfer-"):
|
if self.backend.startswith("flashinfer-"):
|
||||||
backend_name = self.backend[len("flashinfer-") :]
|
backend_name = self.backend[len("flashinfer-") :]
|
||||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||||
@@ -1334,6 +1350,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
assert self.backend == "cutlass"
|
assert self.backend == "cutlass"
|
||||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||||
|
|
||||||
|
# Slice output to remove N-dimension padding
|
||||||
|
out = slice_nvfp4_output(out, output_size)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
out = out + bias
|
out = out + bias
|
||||||
return out.view(*output_shape)
|
return out.view(*output_shape)
|
||||||
|
|||||||
@@ -868,3 +868,70 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens
|
|||||||
t |= ((nib - 8) & 0xF) << shift
|
t |= ((nib - 8) & 0xF) << shift
|
||||||
|
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def round_up(x: int, m: int) -> int:
    """Return ``x`` rounded up to a multiple of ``m``.

    Implemented as integer ceiling division followed by a multiplication, so
    for a positive ``m`` the result is the smallest multiple of ``m`` that is
    greater than or equal to ``x``. A value that is already a multiple of
    ``m`` is returned unchanged.
    """
    # Adding (m - 1) before the floor division turns it into a ceiling.
    multiples_needed = (x + m - 1) // m
    return multiples_needed * m
|
||||||
|
|
||||||
|
|
||||||
|
def pad_nvfp4_weight_for_cutlass(
    weight: torch.Tensor,
    alignment: int = 32,
) -> tuple[torch.Tensor, int]:
    """Zero-pad a packed NVFP4 weight so N and K are multiples of ``alignment``.

    The CUTLASS / FlashInfer FP4 GEMM kernels require both the N (row) and
    K (column) dimensions of the weight matrix to be divisible by 32. The
    weight is stored packed, two FP4 values per byte, so the K extent in FP4
    elements is twice ``weight.shape[1]``.

    Args:
        weight: Packed weight tensor; dim 0 is N (rows) and dim 1 is the
            packed K dimension measured in bytes.
        alignment: Required divisor for N and for K (counted in FP4
            elements). Defaults to 32.

    Returns:
        A ``(weight, padding_bytes)`` tuple. ``weight`` is the input tensor
        itself when no padding was needed, otherwise a new contiguous tensor
        padded at the end of the affected dimension(s). ``padding_bytes`` is
        the number of bytes appended to dim 1 (0 when K was already aligned);
        padding added to the rows is not reported.
    """
    pad = torch.nn.functional.pad

    # --- N dimension (rows) ---
    num_rows = weight.shape[0]
    if num_rows % alignment != 0:
        aligned_rows = (num_rows + alignment - 1) // alignment * alignment
        # F.pad takes pairs starting from the last dim: (left, right, top, bottom).
        weight = pad(weight, (0, 0, 0, aligned_rows - num_rows)).contiguous()

    # --- K dimension (columns) ---
    # Alignment is checked in FP4 elements; each stored byte holds two of them.
    k_elements = weight.shape[1] * 2
    if k_elements % alignment == 0:
        return weight, 0

    aligned_k = (k_elements + alignment - 1) // alignment * alignment
    # k_elements is always even, so with an even alignment (such as the
    # default 32) the shortfall is even as well and halving it is exact.
    extra_bytes = (aligned_k - k_elements) // 2
    weight = pad(weight, (0, extra_bytes, 0, 0)).contiguous()
    return weight, extra_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def pad_nvfp4_activation_for_cutlass(
    x_fp4: torch.Tensor,
    weights_padding_bytes: int,
) -> torch.Tensor:
    """Zero-pad packed FP4 activations along the last dimension.

    Used so the activation's K extent matches a weight whose K dimension was
    padded. The amount is expressed in bytes (i.e. in units of the packed
    tensor's last dimension), not in FP4 elements.

    Args:
        x_fp4: Packed FP4 activation tensor.
        weights_padding_bytes: Number of bytes to append to the last
            dimension.

    Returns:
        ``x_fp4`` itself when ``weights_padding_bytes`` is not positive,
        otherwise a new contiguous tensor with that many zero bytes appended
        to the last dimension.
    """
    if weights_padding_bytes <= 0:
        # Nothing to add; hand back the caller's tensor untouched.
        return x_fp4

    widened = torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes))
    return widened.contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def slice_nvfp4_output(
    out: torch.Tensor,
    output_size: int,
) -> torch.Tensor:
    """Trim the last dimension of ``out`` down to ``output_size`` entries.

    Removes the extra output columns produced when the weight's N dimension
    was padded before the matmul.

    Args:
        out: Matmul result whose last dimension may include padding.
        output_size: Number of leading entries of the last dimension to keep.

    Returns:
        ``out`` itself when its last dimension already equals
        ``output_size``; otherwise a contiguous copy of
        ``out[..., :output_size]``.
    """
    if out.shape[-1] == output_size:
        # No padding present; avoid an unnecessary copy.
        return out
    return out[..., :output_size].contiguous()
|
||||||
|
|||||||
Reference in New Issue
Block a user