Refactor NVFP4 Linear utils for ModelOpt and CT (#33201)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-01-30 19:37:42 -05:00
committed by GitHub
parent 2b465570e6
commit 67ebaff528
12 changed files with 462 additions and 483 deletions

View File

@@ -25,7 +25,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (

View File

@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (

View File

@@ -18,7 +18,6 @@ from compressed_tensors.quantization import (
)
from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -63,9 +62,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
)
from vllm.platforms import current_platform
if TYPE_CHECKING:
@@ -627,14 +623,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if self._is_nvfp4_format(weight_quant) and self._is_nvfp4_format(
input_quant
):
if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
return CompressedTensorsW4A4Fp4()
else:
logger.warning_once(
"Current platform does not support cutlass NVFP4."
" Running CompressedTensorsW4A16Fp4."
)
return CompressedTensorsW4A16Fp4(has_input_global_scale=True)
return CompressedTensorsW4A4Fp4()
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(

View File

@@ -81,7 +81,7 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
)
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer) -> None:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Rename weight_packed to weight that marlin expects
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
del layer.weight_packed
@@ -98,7 +98,7 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=None,
weight_global_scale=None,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,

View File

@@ -22,8 +22,7 @@ __all__ = ["CompressedTensorsW4A16Fp4"]
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
def __init__(self, has_input_global_scale: bool = False):
self.has_input_global_scale = has_input_global_scale
def __init__(self):
self.group_size = 16
@classmethod
@@ -79,30 +78,16 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale)
if self.has_input_global_scale:
input_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_global_scale", input_global_scale)
def process_weights_after_loading(self, layer) -> None:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process parameters for marlin repacking
# Rename weight_packed to weight that marlin expects
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
del layer.weight_packed
# Rename weight_global_scale to weight_scale_2 that marlin expects
# Note: ct stores the inverse of what is expected by the marlin kernel
layer.weight_scale_2 = Parameter(
1 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False
# ct stores the inverse of what is expected by the marlin kernel
layer.weight_global_scale = Parameter(
1.0 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False
)
del layer.weight_global_scale
if self.has_input_global_scale:
layer.input_global_scale = torch.nn.Parameter(
layer.input_global_scale.data, requires_grad=False
)
prepare_fp4_layer_for_marlin(layer)
@@ -116,7 +101,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=layer.weight_scale_2,
weight_global_scale=layer.weight_global_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,

View File

@@ -5,75 +5,31 @@ from collections.abc import Callable
import torch
from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
run_nvfp4_emulations,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale,
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
apply_nvfp4_linear,
convert_to_nvfp4_linear_kernel_format,
select_nvfp4_linear_backend,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
)
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A4Fp4"]
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
def __init__(self):
self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
if has_flashinfer():
self.backend = "flashinfer-cutlass"
elif cutlass_fp4_supported():
self.backend = "cutlass"
elif envs.VLLM_USE_FBGEMM:
self.backend = "fbgemm"
try:
import fbgemm_gpu # noqa: F401
except ImportError as exc:
raise ImportError(
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
"Please install with: pip install fbgemm-gpu-genai"
) from exc
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
self.backend = "cutlass"
assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
if self.backend == "none":
raise ValueError(
"No valid NVFP4 GEMM backend found. "
"Please check your platform capability."
)
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
self.backend = select_nvfp4_linear_backend()
self.group_size = 16
@classmethod
def get_min_capability(cls) -> int:
if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
return 80
return 100
return 75
def create_weights(
self,
@@ -129,120 +85,40 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
)
layer.register_parameter("input_global_scale", input_global_scale)
def process_weights_after_loading(self, layer) -> None:
global_input_scale = layer.input_global_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(global_input_scale, requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Rename CT checkpoint names to standardized names
layer.weight = layer.weight_packed
del layer.weight_packed
# Process global scales (CT stores as divisors, i.e. 1/scale)
input_global_scale_inv = layer.input_global_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(
(1.0 / input_global_scale_inv).to(torch.float32), requires_grad=False
)
weight_global_scale = layer.weight_global_scale.max().to(torch.float32)
layer.weight_global_scale = Parameter(
layer.weight_global_scale.max().to(torch.float32), requires_grad=False
1.0 / weight_global_scale, requires_grad=False
)
if self.backend == "flashinfer-trtllm":
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout but we use our own quantization so we have to call
# shuffles ourselves.
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
weight = layer.weight_packed.data
weight_scale = layer.weight_scale.data
epilogue_tile_m = 128
weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
weight_scale = (
shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
.reshape(weight_scale.shape)
.view(torch.float8_e4m3fn)
)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
if self.backend == "fbgemm":
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
# Pad weights for CUTLASS/FlashInfer kernel alignment (K and N
# divisible by 32). fbgemm has its own layout requirements.
if self.backend in ("cutlass", "flashinfer-cutlass"):
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight_packed.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
layer.weights_padding_cols = 0
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False
)
# Pre-compute alpha and inverse for runtime quantization
layer.input_global_scale_inv = Parameter(
input_global_scale_inv, requires_grad=False
)
layer.alpha = Parameter(
1 / (layer.input_global_scale * layer.weight_global_scale),
requires_grad=False,
layer.input_global_scale * layer.weight_global_scale, requires_grad=False
)
# Convert layer to NVFP4 linear kernel format
convert_to_nvfp4_linear_kernel_format(self.backend, layer)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
out = run_nvfp4_emulations(
x=x,
input_global_scale=layer.input_global_scale,
weight=layer.weight_packed,
weight_scale_swizzled=layer.weight_scale,
weight_global_scale=layer.weight_global_scale,
)
if bias is not None:
out = out + bias
return out
output_dtype = x.dtype
output_size = layer.output_size_per_partition
output_shape = [*x.shape[:-1], output_size]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(
x,
layer.input_global_scale,
is_sf_swizzled_layout=True,
return apply_nvfp4_linear(
backend=self.backend,
layer=layer,
x=x,
bias=bias,
)
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = (
x_fp4,
layer.weight_packed,
x_blockscale,
layer.weight_scale,
layer.alpha,
output_dtype,
)
if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
elif self.backend == "fbgemm":
out = torch.ops.fbgemm.f4f4bf16(
x_fp4,
layer.weight_packed,
x_blockscale.view(-1).view(torch.uint8),
layer.weight_scale,
layer.alpha,
use_mx=False,
).to(output_dtype)
else:
assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None:
out = out + bias
return out.view(*output_shape)

View File

@@ -5,12 +5,9 @@ from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe.config import (
@@ -66,24 +63,19 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin,
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
apply_nvfp4_linear,
convert_to_nvfp4_linear_kernel_format,
select_nvfp4_linear_backend,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
cutlass_fp4_supported,
is_layer_skipped,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
kNvfp4Dynamic,
kNvfp4Static,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
@@ -96,11 +88,6 @@ from vllm.model_executor.parameter import (
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
)
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -498,7 +485,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
def process_weights_after_loading(self, layer: Module) -> None:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight
max_w_scale = layer.weight_scale.max()
if not (layer.weight_scale == layer.weight_scale[0]).all():
@@ -580,7 +567,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: Module) -> None:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
@@ -681,7 +668,7 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: Module) -> None:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
layer.weight = Parameter(layer.weight.data, requires_grad=False)
@@ -1108,32 +1095,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
self.marlin_input_dtype = None
self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
if current_platform.has_device_capability(100) and has_flashinfer():
self.backend = "flashinfer-cutlass"
elif cutlass_fp4_supported():
self.backend = "cutlass"
elif is_fp4_marlin_supported():
self.backend = "marlin"
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
self.backend = "cutlass"
assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
elif envs.VLLM_NVFP4_GEMM_BACKEND == "marlin":
self.backend = "marlin"
assert is_fp4_marlin_supported(), f"Marlin is required for {self.backend}"
if self.backend == "none":
raise ValueError(
"No valid NVFP4 GEMM backend found. "
"Please check your platform capability."
)
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
self.backend = select_nvfp4_linear_backend()
def create_weights(
self,
@@ -1181,19 +1143,19 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
)
layer.register_parameter("weight", weight)
# Input Weight Scale
input_scale = PerTensorScaleParameter(
# Input Global Scale
input_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("input_scale", input_global_scale)
# Global Weight Scale
weight_scale_2 = PerTensorScaleParameter(
# Weight Global Scale
weight_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale_2", weight_scale_2)
layer.register_parameter("weight_scale_2", weight_global_scale)
# Per Block Weight Scale
weight_scale = ModelWeightParameter(
@@ -1209,65 +1171,25 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: Module) -> None:
# global scales:
input_scale_2 = layer.input_scale.max().to(torch.float32)
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Rename ModelOpt checkpoint names to standardized names
input_global_scale = layer.input_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(input_global_scale, requires_grad=False)
del layer.input_scale
weight_global_scale = layer.weight_scale_2.max().to(torch.float32)
layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False)
del layer.weight_scale_2
# Pre-compute alpha and inverse for runtime quantization
layer.alpha = Parameter(
layer.input_scale * layer.weight_scale_2, requires_grad=False
layer.input_global_scale * layer.weight_global_scale, requires_grad=False
)
layer.input_global_scale_inv = Parameter(
(1.0 / layer.input_global_scale).to(torch.float32), requires_grad=False
)
# Calculate `1 / input_scale` so that we don't need to do so at runtime
layer.input_scale_inv = Parameter(
(1 / layer.input_scale).to(torch.float32), requires_grad=False
)
# Swizzle the weight blockscale.
# contracting dimension is input dimension
# block_size = 16;
assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
"Weight Block scale must be represented as FP8-E4M3"
)
if self.backend == "marlin":
prepare_fp4_layer_for_marlin(layer)
del layer.alpha
del layer.input_scale
elif self.backend == "flashinfer-trtllm":
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout but we use our own quantization so we have to call
# shuffles ourselves.
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
weight = layer.weight.data
weight_scale = layer.weight_scale.data
epilogue_tile_m = 128
weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
weight_scale = (
shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
.reshape(weight_scale.shape)
.view(torch.float8_e4m3fn)
)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight = Parameter(weight, requires_grad=False)
else:
# Swizzle block scales and pad the packed NVFP4 weights for kernel
# alignment (CUTLASS/FlashInfer require K and N divisible by 32).
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight = Parameter(weight, requires_grad=False)
# Convert layer to NVFP4 linear kernel format
convert_to_nvfp4_linear_kernel_format(self.backend, layer)
def apply(
self,
@@ -1275,63 +1197,13 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if self.backend == "marlin":
return apply_fp4_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=layer.weight_scale_2,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
input_dtype=self.marlin_input_dtype,
)
output_dtype = x.dtype
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(
x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend
return apply_nvfp4_linear(
backend=self.backend,
layer=layer,
x=x,
bias=bias,
)
# validate dtypes of quantized input, input block scale,
# weight and weight_blockscale
assert x_fp4.dtype == torch.uint8
assert layer.weight.dtype == torch.uint8
assert x_blockscale.dtype == torch.float8_e4m3fn
assert layer.weight_scale.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
output_size = layer.output_size_per_partition
output_shape = [x.shape[0], output_size]
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = (
x_fp4,
layer.weight,
x_blockscale,
layer.weight_scale,
layer.alpha,
output_dtype,
)
if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
else:
assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None:
out = out + bias
return out.view(*output_shape)
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"""

View File

@@ -15,11 +15,13 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
swizzle_blockscale,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (

View File

@@ -92,7 +92,7 @@ def apply_fp4_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor | None,
weight_global_scale: torch.Tensor | None,
workspace: torch.Tensor,
size_n: int,
size_k: int,
@@ -112,7 +112,7 @@ def apply_fp4_marlin_linear(
inputs = reshaped_x
a_scales = None
is_nvfp4 = weight_scale_2 is not None
is_nvfp4 = weight_global_scale is not None
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
@@ -128,7 +128,7 @@ def apply_fp4_marlin_linear(
b_bias=bias,
b_scales=weight_scale,
a_scales=a_scales,
global_scale=weight_scale_2,
global_scale=weight_global_scale,
b_zeros=None,
g_idx=None,
perm=None,
@@ -154,7 +154,7 @@ def prepare_fp4_layer_for_marlin(
"performance for compute-heavy workloads."
)
is_nvfp4 = hasattr(layer, "weight_scale_2")
is_nvfp4 = hasattr(layer, "weight_global_scale")
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
@@ -210,9 +210,11 @@ def prepare_fp4_layer_for_marlin(
weight_scale = nvfp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False)
weight_global_scale = layer.weight_global_scale.to(param_dtype)
weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
layer.weight_global_scale = torch.nn.Parameter(
weight_global_scale, requires_grad=False
)
else:
weight_scale = mxfp4_marlin_process_scales(
weight_scale, input_dtype=input_dtype

View File

@@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
is_fp4_marlin_supported,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)

View File

@@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
import vllm.envs as envs
from vllm._custom_ops import (
cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4,
scaled_fp4_quant,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
run_nvfp4_emulations,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
class NvFp4LinearBackend(Enum):
VLLM_CUTLASS = "cutlass"
FLASHINFER_CUTLASS = "flashinfer-cutlass"
FLASHINFER_TRTLLM = "flashinfer-trtllm"
FLASHINFER_CUDNN = "flashinfer-cudnn"
FBGEMM = "fbgemm"
MARLIN = "marlin"
EMULATION = "emulation"
def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
"""
Select the best available NVFP4 GEMM backend based on environment
configuration and platform capabilities.
"""
backend: NvFp4LinearBackend | None = None
if envs.VLLM_USE_FBGEMM:
try:
import fbgemm_gpu # noqa: F401
except ImportError as exc:
raise ImportError(
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
"Please install with: pip install fbgemm-gpu-genai"
) from exc
backend = NvFp4LinearBackend.FBGEMM
elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
backend = NvFp4LinearBackend.EMULATION
elif envs.VLLM_NVFP4_GEMM_BACKEND is None:
# Auto-select best available backend
if current_platform.has_device_capability(100) and has_flashinfer():
backend = NvFp4LinearBackend.FLASHINFER_CUTLASS
elif cutlass_fp4_supported():
backend = NvFp4LinearBackend.VLLM_CUTLASS
elif is_fp4_marlin_supported():
backend = NvFp4LinearBackend.MARLIN
else:
backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND)
# Validate that the backend is supported
if backend in (
NvFp4LinearBackend.FLASHINFER_CUTLASS,
NvFp4LinearBackend.FLASHINFER_TRTLLM,
NvFp4LinearBackend.FLASHINFER_CUDNN,
):
assert has_flashinfer(), f"FlashInfer is required for {backend}"
elif backend == NvFp4LinearBackend.VLLM_CUTLASS:
assert cutlass_fp4_supported(), f"Cutlass is required for {backend}"
elif backend == NvFp4LinearBackend.MARLIN:
assert is_fp4_marlin_supported(), f"Marlin is required for {backend}"
elif backend is None:
raise ValueError(
f"No NVFP4 GEMM backend selected, "
f"available backends: {list(NvFp4LinearBackend)}"
)
logger.info_once(f"Using {backend} for NVFP4 GEMM")
return backend
def prepare_weights_for_nvfp4_flashinfer_trtllm(
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare weights and scales for FlashInfer TRTLLM FP4 GEMM."""
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
epilogue_tile_m = 128
shuffled_weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
shuffled_weight_scale = (
shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
.reshape(weight_scale.shape)
.view(torch.float8_e4m3fn)
)
return shuffled_weight, shuffled_weight_scale
def prepare_weights_for_nvfp4_cutlass(
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""
Prepare weights and scales for CUTLASS/FlashInfer-CUTLASS FP4 GEMM.
This involves padding weights for alignment (K and N divisible by 32)
"""
swizzled_weight_scale = swizzle_blockscale(weight_scale)
padded_weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(weight)
return padded_weight, swizzled_weight_scale, weights_padding_cols
def prepare_weights_for_nvfp4_fbgemm(
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare weights and scales for FBGEMM FP4 GEMM."""
swizzled_weight_scale = swizzle_blockscale(weight_scale)
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
return weight, swizzled_weight_scale
def convert_to_nvfp4_linear_kernel_format(
backend: NvFp4LinearBackend,
layer: torch.nn.Module,
) -> None:
"""Convert layer to NVFP4 linear kernel format."""
assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
"Weight Block scale must be represented as FP8-E4M3"
)
# Default to no padding
layer.weights_padding_cols = 0
if backend == NvFp4LinearBackend.MARLIN:
prepare_fp4_layer_for_marlin(layer)
elif backend == NvFp4LinearBackend.FLASHINFER_TRTLLM:
weight, weight_scale = prepare_weights_for_nvfp4_flashinfer_trtllm(
layer.weight.data, layer.weight_scale.data
)
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
elif backend == NvFp4LinearBackend.FBGEMM:
weight, weight_scale = prepare_weights_for_nvfp4_fbgemm(
layer.weight.data, layer.weight_scale.data
)
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
elif (
backend == NvFp4LinearBackend.VLLM_CUTLASS
or backend == NvFp4LinearBackend.FLASHINFER_CUTLASS
):
weight, weight_scale, weights_padding_cols = prepare_weights_for_nvfp4_cutlass(
layer.weight.data, layer.weight_scale.data
)
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
layer.weights_padding_cols = weights_padding_cols
def apply_nvfp4_linear(
backend: NvFp4LinearBackend,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Apply NVFP4 linear transformation using the specified backend.
"""
weight = layer.weight
weight_scale = layer.weight_scale
weight_global_scale = layer.weight_global_scale
input_global_scale_inv = layer.input_global_scale_inv
alpha = layer.alpha
output_size = layer.output_size_per_partition
input_size = layer.input_size_per_partition
if backend == NvFp4LinearBackend.MARLIN:
return apply_fp4_marlin_linear(
input=x,
weight=weight,
weight_scale=weight_scale,
weight_global_scale=weight_global_scale,
workspace=layer.workspace,
size_n=output_size,
size_k=input_size,
bias=bias,
)
elif backend == NvFp4LinearBackend.EMULATION:
out = run_nvfp4_emulations(
x=x,
input_global_scale=input_global_scale_inv,
weight=weight,
weight_scale_swizzled=weight_scale,
weight_global_scale=weight_global_scale,
)
if bias is not None:
out = out + bias
return out
output_dtype = x.dtype
output_shape = [*x.shape[:-1], output_size]
# Quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(
x, input_global_scale_inv, is_sf_swizzled_layout=True, backend=backend.value
)
# Validate dtypes
assert x_fp4.dtype == torch.uint8
assert weight.dtype == torch.uint8
assert x_blockscale.dtype == torch.float8_e4m3fn
# weight_scale is fp8 for most backends, but uint8 for fbgemm
assert weight_scale.dtype in (torch.float8_e4m3fn, torch.uint8)
assert alpha.dtype == torch.float32
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
# Prepare args for the matmul
mm_args = (
x_fp4,
weight,
x_blockscale,
weight_scale,
alpha,
output_dtype,
)
# Call the appropriate backend
if backend.value.startswith("flashinfer-"):
backend_name = backend.value[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
elif backend == NvFp4LinearBackend.FBGEMM:
out = torch.ops.fbgemm.f4f4bf16(
x_fp4,
weight,
x_blockscale.view(-1).view(torch.uint8),
weight_scale,
alpha,
use_mx=False,
).to(output_dtype)
else:
assert backend == NvFp4LinearBackend.VLLM_CUTLASS
out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None:
out = out + bias
return out.view(*output_shape)
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
"""
Pad and block-interleave the FP4 block-scales so that they match the data
layout expected by the CUTLASS / FlashInfer kernels.
Parameters
----------
scale: torch.Tensor
Returns
-------
torch.Tensor
The swizzled tensor with the same logical shape as *scale*.
"""
assert scale.dtype == torch.float8_e4m3fn, (
"swizzle_blockscale expects the input tensor to be in "
"torch.float8_e4m3fn format."
)
scale_ndim = scale.ndim
if scale_ndim == 2:
scale = scale.unsqueeze(0) # (1, M, K)
assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
B, M, K = scale.shape
M_padded = round_up(M, 128)
K_padded = round_up(K, 4)
padded = torch.zeros(
(B, M_padded, K_padded), dtype=scale.dtype, device=scale.device
)
padded[:B, :M, :K] = scale
# Reshape / permute to the layout required by the kernel.
padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
if scale_ndim == 2:
return swizzled.reshape(M_padded, K_padded)
return swizzled.reshape(B, M_padded, K_padded)
def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)
def pad_nvfp4_weight_for_cutlass(
weight: torch.Tensor,
alignment: int = 32,
) -> tuple[torch.Tensor, int]:
"""
Pad packed NVFP4 weights so that both N (rows) and K (columns) satisfy
the alignment constraints required by CUTLASS / FlashInfer FP4 kernels.
CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible
by 32 for aligned memory access and efficient tensor core operations.
"""
weight_current_rows = weight.shape[0]
# Pad N dimension (rows) if not aligned
if weight_current_rows % alignment != 0:
total_rows = round_up(weight_current_rows, alignment)
pad_rows = total_rows - weight_current_rows
weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_rows)).contiguous()
# Check K dimension alignment
# 2 FP4 items are packed per byte in the input dimension
weight_current_col_bytes = weight.shape[1]
weight_current_col_elements = weight_current_col_bytes * 2
weights_padding_bytes = 0
if weight_current_col_elements % alignment != 0:
total_cols = round_up(weight_current_col_elements, alignment)
pad_cols = total_cols - weight_current_col_elements
# Convert from FP4 element count to bytes (2 FP4 values per byte)
# pad_cols is always even since alignment=32 and current elements are even
pad_bytes = pad_cols // 2
weight = torch.nn.functional.pad(weight, (0, pad_bytes, 0, 0)).contiguous()
weights_padding_bytes = pad_bytes
return weight, weights_padding_bytes
def pad_nvfp4_activation_for_cutlass(
x_fp4: torch.Tensor,
weights_padding_bytes: int,
) -> torch.Tensor:
"""
Pad packed FP4 activations to match the K-dimension padding applied to weights.
The padding is in bytes (tensor dimension), not FP4 elements.
"""
if weights_padding_bytes > 0:
return torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes)).contiguous()
return x_fp4
def slice_nvfp4_output(
out: torch.Tensor,
output_size: int,
) -> torch.Tensor:
"""
Slice the output tensor to remove padding in N dimension if weight was padded.
"""
if out.shape[-1] != output_size:
return out[..., :output_size].contiguous()
return out

View File

@@ -11,7 +11,6 @@ import numpy
import torch
from torch import fx
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -768,60 +767,6 @@ def awq_pack(
return pack_cols(q_w, num_bits, size_k, size_n)
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
"""
Pad and block-interleave the FP4 block-scales so that they match the data
layout expected by the CUTLASS / FlashInfer kernels.
Parameters
----------
scale: torch.Tensor
Returns
-------
torch.Tensor
The swizzled tensor with the same logical shape as *scale*.
"""
assert scale.dtype == torch.float8_e4m3fn, (
"swizzle_blockscale expects the input tensor to be in "
"torch.float8_e4m3fn format."
)
scale_ndim = scale.ndim
if scale_ndim == 2:
scale = scale.unsqueeze(0) # (1, M, K)
assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
B, M, K = scale.shape
def _round_up(x: int, m: int) -> int:
return (x + m - 1) // m * m
M_padded = _round_up(M, 128)
K_padded = _round_up(K, 4)
padded = torch.zeros(
(B, M_padded, K_padded), dtype=scale.dtype, device=scale.device
)
padded[:B, :M, :K] = scale
# Reshape / permute to the layout required by the kernel.
padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
if scale_ndim == 2:
return swizzled.reshape(M_padded, K_padded)
return swizzled.reshape(B, M_padded, K_padded)
def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)
def convert_bf16_scales_to_fp8(
quant_fp8: Callable, scales: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -868,70 +813,3 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens
t |= ((nib - 8) & 0xF) << shift
return t
def round_up(x: int, m: int) -> int:
"""Round up x to the nearest multiple of m."""
return (x + m - 1) // m * m
def pad_nvfp4_weight_for_cutlass(
weight: torch.Tensor,
alignment: int = 32,
) -> tuple[torch.Tensor, int]:
"""
Pad packed NVFP4 weights so that both N (rows) and K (columns) satisfy
the alignment constraints required by CUTLASS / FlashInfer FP4 kernels.
CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible
by 32 for aligned memory access and efficient tensor core operations.
"""
weight_current_rows = weight.shape[0]
# Pad N dimension (rows) if not aligned
if weight_current_rows % alignment != 0:
total_rows = round_up(weight_current_rows, alignment)
pad_rows = total_rows - weight_current_rows
weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_rows)).contiguous()
# Check K dimension alignment
# 2 FP4 items are packed per byte in the input dimension
weight_current_col_bytes = weight.shape[1]
weight_current_col_elements = weight_current_col_bytes * 2
weights_padding_bytes = 0
if weight_current_col_elements % alignment != 0:
total_cols = round_up(weight_current_col_elements, alignment)
pad_cols = total_cols - weight_current_col_elements
# Convert from FP4 element count to bytes (2 FP4 values per byte)
# pad_cols is always even since alignment=32 and current elements are even
pad_bytes = pad_cols // 2
weight = torch.nn.functional.pad(weight, (0, pad_bytes, 0, 0)).contiguous()
weights_padding_bytes = pad_bytes
return weight, weights_padding_bytes
def pad_nvfp4_activation_for_cutlass(
x_fp4: torch.Tensor,
weights_padding_bytes: int,
) -> torch.Tensor:
"""
Pad packed FP4 activations to match the K-dimension padding applied to weights.
The padding is in bytes (tensor dimension), not FP4 elements.
"""
if weights_padding_bytes > 0:
return torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes)).contiguous()
return x_fp4
def slice_nvfp4_output(
out: torch.Tensor,
output_size: int,
) -> torch.Tensor:
"""
Slice the output tensor to remove padding in N dimension if weight was padded.
"""
if out.shape[-1] != output_size:
return out[..., :output_size].contiguous()
return out