[Transform] [Quantization] Add QuTLASS support to vLLM (#24440)

Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Signed-off-by: Andrei Panferov <andrei@panferov.org>
Co-authored-by: Andrei Panferov <andrei@panferov.org>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: Roberto L. Castro
Date: 2025-10-10 18:43:40 +02:00 (committed by GitHub)
Parent: 8d2b8c0ff2
Commit: 96ad65b7fe
12 changed files with 1848 additions and 1 deletions

@@ -12,6 +12,7 @@ QuantizationMethods = Literal[
     "fp8",
     "ptpc_fp8",
     "fbgemm_fp8",
+    "fp_quant",
     "modelopt",
     "modelopt_fp4",
     "bitblas",
@@ -102,6 +103,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
     from .experts_int8 import ExpertsInt8Config
     from .fbgemm_fp8 import FBGEMMFp8Config
     from .fp8 import Fp8Config
+    from .fp_quant import FPQuantConfig
     from .gguf import GGUFConfig
     from .gptq import GPTQConfig
     from .gptq_bitblas import GPTQBitBLASConfig
@@ -125,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
         "tpu_int8": Int8TpuConfig,
         "fp8": Fp8Config,
         "fbgemm_fp8": FBGEMMFp8Config,
+        "fp_quant": FPQuantConfig,
         "modelopt": ModelOptFp8Config,
         "modelopt_fp4": ModelOptNvFp4Config,
         "bitblas": BitBLASConfig,

@@ -0,0 +1,420 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202

from typing import Any, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._custom_ops import (
    cutlass_scaled_fp4_mm,
    fusedQuantizeMx,
    fusedQuantizeNv,
    matmul_mxf4_bf16_tn,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


class FPQuantConfig(QuantizationConfig):
    """Config class for FPQuant."""

    def __init__(
        self,
        hadamard_group_size: int = 32,
        forward_dtype: str = "mxfp4",
        forward_method: str = "abs_max",
        pseudoquantization: bool = False,
        modules_to_not_convert: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.hadamard_group_size = hadamard_group_size
        self.forward_dtype = forward_dtype
        self.forward_method = forward_method
        self.pseudoquantization = pseudoquantization
        self.modules_to_not_convert = modules_to_not_convert

        if pseudoquantization:
            raise ValueError("Pseudoquantization is not supported for vLLM")

    def __repr__(self) -> str:
        return (
            f"FPQuantConfig(hadamard_group_size={self.hadamard_group_size}, "
            f"forward_dtype={self.forward_dtype}, "
            f"forward_method={self.forward_method}, "
            f"pseudoquantization={self.pseudoquantization}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp_quant"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 100

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "FPQuantConfig":
        hadamard_group_size = cls.get_from_keys(config, ["hadamard_group_size"])
        forward_dtype = cls.get_from_keys(config, ["forward_dtype"])
        forward_method = cls.get_from_keys(config, ["forward_method"])
        pseudoquantization = cls.get_from_keys(config, ["pseudoquantization"])
        modules_to_not_convert = cls.get_from_keys(config, ["modules_to_not_convert"])
        return cls(
            hadamard_group_size,
            forward_dtype,
            forward_method,
            pseudoquantization,
            modules_to_not_convert,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[LinearMethodBase]:
        if self.modules_to_not_convert is not None and any(
            prefix.endswith(module) for module in self.modules_to_not_convert
        ):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            return FPQuantLinearMethod(self)
        return None
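
For reference, from_config above pulls its fields from the checkpoint's quantization_config. A minimal sketch with hypothetical values (the key names are exactly the ones read by get_from_keys above; the module list is illustrative):

# Hypothetical quantization_config entry from an FP-Quant checkpoint's config.json.
cfg = {
    "hadamard_group_size": 32,
    "forward_dtype": "mxfp4",
    "forward_method": "abs_max",
    "pseudoquantization": False,
    "modules_to_not_convert": ["lm_head"],
}
quant_config = FPQuantConfig.from_config(cfg)
print(quant_config)  # __repr__ shows the parsed fields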


class FPQuantLinearMethod(LinearMethodBase):
    """Linear method for FPQuant.

    Args:
        quant_config: The FPQuant quantization config.
    """

    def __init__(self, quant_config: FPQuantConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        del input_size  # Unused.

        if params_dtype != torch.bfloat16:
            raise ValueError("Only bfloat16 is currently supported by FPQuant")
        if input_size_per_partition % self.quant_config.hadamard_group_size != 0:  # noqa: E501
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by a tensor "
                "parallel size that is too large."
            )

        assert self.quant_config.forward_dtype in ["mxfp4", "nvfp4"], (
            "Only mxfp4 and nvfp4 are supported for now"
        )
        if self.quant_config.forward_dtype == "mxfp4":
            group_size = 32
        elif self.quant_config.forward_dtype == "nvfp4":
            group_size = 16
        else:
            raise ValueError(
                f"Unsupported forward_dtype: {self.quant_config.forward_dtype}"
            )

        qweight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": 2,
            }
            | extra_weight_attrs,
        )
        layer.register_parameter("qweight", qweight)

        scales = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // group_size,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": group_size,
            }
            | extra_weight_attrs,
        )
        layer.register_parameter("scales", scales)

        weight_global_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        set_weight_attrs(
            weight_global_scale, {"ignore_warning": True} | extra_weight_attrs
        )
        layer.register_parameter("weight_global_scale", weight_global_scale)

        act_global_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        set_weight_attrs(
            act_global_scale, {"ignore_warning": True} | extra_weight_attrs
        )
        layer.register_parameter("act_global_scale", act_global_scale)

        forward_hadamard_matrix = Parameter(
            torch.empty(
                self.quant_config.hadamard_group_size,
                self.quant_config.hadamard_group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            forward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs
        )
        layer.register_parameter("forward_hadamard_matrix", forward_hadamard_matrix)

        backward_hadamard_matrix = Parameter(
            torch.empty(
                self.quant_config.hadamard_group_size,
                self.quant_config.hadamard_group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            backward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs
        )
        layer.register_parameter("backward_hadamard_matrix", backward_hadamard_matrix)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return quantized_forward(
            x,
            layer.qweight,
            layer.scales,
            layer.weight_global_scale,
            layer.act_global_scale,
            bias,
            layer.forward_hadamard_matrix,
            self.quant_config.forward_method,
            self.quant_config.forward_dtype,
        )


def ceil_div(a, b):
    return (a + b - 1) // b


def fused_quantize_mx(
    x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str
) -> tuple[torch.Tensor, torch.Tensor]:
    return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method)


def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method):
    rows, cols = x_flat.size(0), x_flat.size(1) // 32
    padded_rows = ((rows + 128 - 1) // 128) * 128
    padded_cols = ((cols + 4 - 1) // 4) * 4

    xh_e2m1 = torch.empty(
        x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device
    )
    xh_e8m0 = torch.empty(
        padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=x_flat.device
    )
    return xh_e2m1, xh_e8m0


direct_register_custom_op(
    op_name="fused_quantize_mx",
    op_func=fused_quantize_mx,
    mutates_args=[],
    fake_impl=fused_quantize_mx_fake,
    dispatch_key=current_platform.dispatch_key,
)


def matmul_mxf4_bf16(
    x: torch.Tensor,
    w: torch.Tensor,
    xs: torch.Tensor,
    ws: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    return matmul_mxf4_bf16_tn(
        x,
        w,
        to_blocked(xs, backend="triton").view(torch.float8_e8m0fnu),
        to_blocked(ws, backend="triton").view(torch.float8_e8m0fnu),
        alpha,
    )


def matmul_mxf4_bf16_fake(x, w, xs, ws, alpha):
    return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device)


direct_register_custom_op(
    op_name="matmul_mxf4_bf16",
    op_func=matmul_mxf4_bf16,
    mutates_args=[],
    fake_impl=matmul_mxf4_bf16_fake,
    dispatch_key=current_platform.dispatch_key,
)


def fused_quantize_nv(
    x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, global_scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    return fusedQuantizeNv(x_flat, hadamard_matrix, global_scale)


def fused_quantize_nv_fake(x_flat, hadamard_matrix, global_scale):
    rows, cols = x_flat.size(0), x_flat.size(1) // 16
    padded_rows = ((rows + 128 - 1) // 128) * 128
    padded_cols = ((cols + 4 - 1) // 4) * 4

    xh_e2m1 = torch.empty(
        x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device
    )
    xh_e8m0 = torch.empty(
        padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=x_flat.device
    )
    return xh_e2m1, xh_e8m0


direct_register_custom_op(
    op_name="fused_quantize_nv",
    op_func=fused_quantize_nv,
    mutates_args=[],
    fake_impl=fused_quantize_nv_fake,
    dispatch_key=current_platform.dispatch_key,
)


def matmul_nvf4_bf16(
    x: torch.Tensor,
    w: torch.Tensor,
    xs: torch.Tensor,
    ws: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    return cutlass_scaled_fp4_mm(
        x,
        w,
        to_blocked(xs, backend="triton")
        .view(torch.float8_e4m3fn)
        .view(-1, x.shape[1] // 8),  # *2//16
        to_blocked(ws, backend="triton")
        .view(torch.float8_e4m3fn)
        .view(-1, x.shape[1] // 8),
        alpha,
        torch.bfloat16,
    )


def matmul_nvf4_bf16_fake(x, w, xs, ws, alpha):
    return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device)


direct_register_custom_op(
    op_name="matmul_nvf4_bf16",
    op_func=matmul_nvf4_bf16,
    mutates_args=[],
    fake_impl=matmul_nvf4_bf16_fake,
    dispatch_key=current_platform.dispatch_key,
)
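
The fake implementations above only need to report output shapes for tracing and torch.compile; the padded scale shape they produce can be checked with plain arithmetic (hypothetical sizes):

# Hypothetical activation (M, K) = (8, 4096) quantized as mxfp4 (group size 32):
# the scale grid is padded to multiples of (128, 4), matching the blocked layout.
M, K, group = 8, 4096, 32
rows, cols = M, K // group                       # (8, 128) raw scale grid
padded_rows = ((rows + 128 - 1) // 128) * 128    # 128
padded_cols = ((cols + 4 - 1) // 4) * 4          # 128
print(padded_rows, padded_cols)                  # 128 128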


def quantized_forward(
    x: torch.Tensor,
    qweight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_global_scale: torch.Tensor,
    act_global_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    forward_hadamard_matrix: torch.Tensor,
    forward_method: str,
    forward_dtype: str,
) -> torch.Tensor:
    x_flat = x.contiguous().flatten(end_dim=-2)

    if forward_dtype == "mxfp4":
        x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_mx(
            x_flat, forward_hadamard_matrix, forward_method
        )
        y = torch.ops.vllm.matmul_mxf4_bf16(
            x_flat_q,
            qweight,
            x_flat_scales,
            weight_scales,
            1 / (weight_global_scale * act_global_scale),
        )
    elif forward_dtype == "nvfp4":
        x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_nv(
            x_flat, forward_hadamard_matrix, act_global_scale
        )
        y = torch.ops.vllm.matmul_nvf4_bf16(
            x_flat_q,
            qweight,
            x_flat_scales,
            weight_scales,
            1 / (weight_global_scale * act_global_scale),
        )
    else:
        raise ValueError(f"Unsupported forward_dtype: {forward_dtype}")

    y = y.view(*x.shape[:-1], y.shape[-1])
    if bias is not None:
        y += bias
    return y
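
As a shape-level summary of the mxfp4 path in quantized_forward (hypothetical sizes; the batch dimensions are flattened exactly as above):

# Hypothetical input (batch, seq, hidden) = (2, 8, 4096), weight 4096 -> 11008.
B, T, K, N = 2, 8, 4096, 11008
x_shape = (B, T, K)
x_flat_shape = (B * T, K)           # x.contiguous().flatten(end_dim=-2)
x_flat_q_shape = (B * T, K // 2)    # packed FP4 activations (uint8)
scale_groups = (B * T, K // 32)     # one E8M0 scale per 32 values (padded and swizzled downstream)
qweight_shape = (N, K // 2)         # packed FP4 weights
y_shape = (B, T, N)                 # bf16 output after y.view(*x.shape[:-1], y.shape[-1])
print(x_flat_shape, x_flat_q_shape, qweight_shape, y_shape)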

@@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Modified by Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
#
# Copied from https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Literal

import torch
import triton
import triton.language as tl
from torch.library import wrap_triton


@triton.jit
def triton_scale_swizzle(
    scale_ptr: torch.Tensor,
    scale_rows: int,
    scale_cols: int,
    output_ptr: torch.Tensor,
    input_row_stride: int,
    output_block_stride: int,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_COLS: tl.constexpr,
):
    """
    Rearranges tensor data from row-major to block-scaled swizzle format.

    Args:
        scale_ptr: Pointer to the input scale tensor
        scale_rows: Number of rows in the scale tensor
        scale_cols: Number of columns in the scale tensor
        output_ptr: Pointer to the output tensor
        input_row_stride: Stride between rows in the input tensor
        output_block_stride: Stride between blocks in the output tensor
        BLOCK_ROWS: Number of rows in a tile (compile-time constant)
        BLOCK_COLS: Number of columns in a tile (compile-time constant)
    """
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    rows = tl.arange(0, BLOCK_ROWS)[:, None]
    cols = tl.arange(0, BLOCK_COLS)[None, :]

    # Calculate starting row and column for this tile
    start_row = pid_row * BLOCK_ROWS
    start_col = pid_col * BLOCK_COLS
    global_rows = start_row + rows
    global_cols = start_col + cols

    mask = (global_rows < scale_rows) & (global_cols < scale_cols)

    input_scales = tl.load(
        scale_ptr + global_rows * input_row_stride + global_cols,
        mask=mask,
        other=0.0,
    )

    r_div_32 = rows // 32
    r_mod_32 = rows % 32

    # Rearrange to (32, 4, 4) then to final (32, 16) coordinates
    dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols

    # Flatten
    dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
    scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))

    # Calculate block offset using provided output block stride
    LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
    block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)

    tl.store(
        output_ptr + block_offset + dest_indices_flat,
        scales_flat,
    )
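
The destination index computed in the kernel maps each (row, col) of a 128x4 tile into a 512-element block; a tiny pure-Python reference of that mapping, useful for spot-checking (illustrative only, not part of the file):

# Within one 128x4 tile, element (r, c) goes to offset (r % 32) * 16 + (r // 32) * 4 + c.
def swizzled_offset(r: int, c: int) -> int:
    return (r % 32) * 16 + (r // 32) * 4 + c

assert swizzled_offset(0, 0) == 0
assert swizzled_offset(32, 0) == 4   # rows 32..63 land in the second group of 4 columns
assert swizzled_offset(1, 3) == 19   # (1 % 32) * 16 + 0 * 4 + 3
# The mapping is a bijection onto the 512 slots of the tile.
assert sorted(swizzled_offset(r, c) for r in range(128) for c in range(4)) == list(range(512))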


def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
    """
    Rearranges an E8M0 tensor scale from row-major format to
    block-scaled swizzle format.

    This format is suitable for Tmem as described in NVIDIA documentation:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        scale_tensor: Input tensor in row-major format with 8-bit elements

    Returns:
        Rearranged tensor in block-scaled swizzle format
    """
    assert scale_tensor.element_size() == 1, (
        "Expected element size to be 1 byte (8 bits)"
    )
    assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"

    rows, cols = scale_tensor.shape

    # Calculate blocks needed
    n_row_blocks = triton.cdiv(rows, 128)
    n_col_blocks = triton.cdiv(cols, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    out = scale_tensor.new_empty((padded_rows, padded_cols))

    # Input stride (for row-major format)
    input_row_stride = cols

    # We probably want to handle multiple blocks per tile, but
    # for now keep it simple
    BLOCK_ROWS, BLOCK_COLS = 128, 4

    # Output block stride for the rearranged format
    output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)

    grid = lambda META: (
        triton.cdiv(padded_rows, BLOCK_ROWS),
        triton.cdiv(padded_cols, BLOCK_COLS),
    )

    wrap_triton(triton_scale_swizzle)[grid](
        scale_tensor.view(torch.uint8),
        rows,
        cols,
        out.view(torch.uint8),
        input_row_stride,
        output_block_stride,
        BLOCK_ROWS=BLOCK_ROWS,
        BLOCK_COLS=BLOCK_COLS,
    )

    return out


def ceil_div(a, b):
    return (a + b - 1) // b


def to_blocked(
    input_matrix: torch.Tensor, backend: Literal["torch", "triton"] = "triton"
) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying
    the rearrangement pattern.

    See:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)
        backend: "torch" (PyTorch path) or "triton" (Triton kernel)

    Returns:
        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
    """
    if backend == "triton":
        return triton_mx_block_rearrange(input_matrix).flatten()
    elif backend != "torch":
        raise ValueError(f'backend must be "torch" or "triton", got {backend!r}')

    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    assert (rows, cols) == (padded_rows, padded_cols)

    # Rearrange the blocks
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()
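
A small usage sketch of the PyTorch fallback path (the input must already be padded to the 128x4 block size, as the assert above requires; values are arbitrary):

import torch
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked

scales = torch.randint(0, 255, (128, 4), dtype=torch.uint8)  # one full 128x4 scale tile
blocked = to_blocked(scales, backend="torch")
assert blocked.shape == (512,)        # 32 * ceil_div(128, 128) * 16 * ceil_div(4, 4)
assert blocked[4] == scales[32, 0]    # matches the swizzle mapping shown earlier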