[NVFP4] Support NVFP4 dense models from modelopt and compressed-tensors on AMD Instinct MI300, MI355X and Hopper through emulation (#35733)

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
fxmarty-amd
2026-04-07 00:18:27 +02:00
committed by GitHub
parent 9c81f35b1a
commit 00d7b497b3
10 changed files with 191 additions and 58 deletions

View File

@@ -89,22 +89,33 @@ def test_models(example_prompts, model_name) -> None:
EAGER = [True, False]
SM_100_NVFP4_BACKENDS = [
"flashinfer-cudnn",
"flashinfer-trtllm",
"flashinfer-cutlass",
]
@pytest.mark.skipif(
not current_platform.has_device_capability(100),
reason="modelopt_fp4 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model", ["nvidia/Llama-3.1-8B-Instruct-NVFP4"])
@pytest.mark.parametrize("eager", EAGER)
@pytest.mark.parametrize(
"backend",
[
"emulation",
"flashinfer-cudnn",
"flashinfer-trtllm", # the small seq_len ensures trtllm_8x4_layout backend is used
"flashinfer-cutlass",
],
)
def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch):
if (
not current_platform.has_device_capability(100)
and backend in SM_100_NVFP4_BACKENDS
):
pytest.skip(
f"The backend {backend} is not supported with current_platform.has_device_capability(100) == False"
)
monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend)
with vllm_runner(model, enforce_eager=eager) as llm:
output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)

View File

@@ -366,9 +366,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
assert output
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize(
"args",
[
@@ -398,7 +395,7 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
assert qkv_proj.scheme.group_size == 16
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4)
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
print(output)
assert output

View File

@@ -1464,6 +1464,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
# - "marlin": use marlin GEMM backend (for GPUs without native FP4 support)
# - "emulation":
# use BF16/FP16 GEMM, dequantizing weights and running QDQ on activations.
# This is only meant for research purposes to run on devices where NVFP4
# GEMM kernels are not available.
# - <none>: automatically pick an available backend
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
"VLLM_NVFP4_GEMM_BACKEND",
@@ -1474,6 +1478,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"flashinfer-cutlass",
"cutlass",
"marlin",
"emulation",
],
),
# Controls garbage collection during CUDA graph capture.

View File

@@ -5,10 +5,12 @@ from collections.abc import Callable
import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
NvFp4LinearBackend,
apply_nvfp4_linear,
convert_to_nvfp4_linear_kernel_format,
select_nvfp4_linear_backend,
@@ -19,6 +21,9 @@ from vllm.model_executor.parameter import (
PerTensorScaleParameter,
)
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A4Fp4"]
@@ -27,6 +32,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
self.backend = select_nvfp4_linear_backend()
self.group_size = 16
self.swizzle = None
if self.backend == NvFp4LinearBackend.EMULATION:
self.swizzle = False
@classmethod
def get_min_capability(cls) -> int:
return 75
@@ -89,6 +98,19 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
# Rename CT checkpoint names to standardized names
layer.weight = layer.weight_packed
del layer.weight_packed
if (
torch.unique(layer.input_global_scale).numel() != 1
or torch.unique(layer.weight_global_scale).numel() != 1
):
logger.warning_once(
"In NVFP4 linear, the global scale for input or weight are different"
" for parallel layers (e.g. q_proj, k_proj, v_proj). This "
" will likely result in reduced accuracy. Please verify the model"
" accuracy. Consider using a checkpoint with a shared global NVFP4"
" scale for fused layers."
)
# Process global scales (CT stores as divisors, i.e. 1/scale)
input_global_scale_inv = layer.input_global_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(
@@ -121,4 +143,5 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
layer=layer,
x=x,
bias=bias,
swizzle=self.swizzle,
)

View File

@@ -71,6 +71,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
NvFp4LinearBackend,
apply_nvfp4_linear,
convert_to_nvfp4_linear_kernel_format,
select_nvfp4_linear_backend,
@@ -1074,6 +1075,10 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
self.marlin_input_dtype = None
self.backend = select_nvfp4_linear_backend()
self.swizzle = None
if self.backend == NvFp4LinearBackend.EMULATION:
self.swizzle = False
def create_weights(
self,
layer: torch.nn.Module,
@@ -1149,10 +1154,23 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if (
torch.unique(layer.input_scale).numel() != 1
or torch.unique(layer.weight_scale_2).numel() != 1
):
logger.warning_once(
"In NVFP4 linear, the global scale for input or weight are different"
" for parallel layers (e.g. q_proj, k_proj, v_proj). This "
" will likely results in reduce accuracy. Please verify the model"
" accuracy. Consider using a checkpoint with a shared global NVFP4"
" scale for parallel layers."
)
# Rename ModelOpt checkpoint names to standardized names
input_global_scale = layer.input_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(input_global_scale, requires_grad=False)
del layer.input_scale
weight_global_scale = layer.weight_scale_2.max().to(torch.float32)
layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False)
del layer.weight_scale_2
@@ -1179,6 +1197,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer=layer,
x=x,
bias=bias,
swizzle=self.swizzle,
)

View File

@@ -24,7 +24,7 @@ logger = init_logger(__name__)
def is_fp4_marlin_supported():
return current_platform.has_device_capability(75)
return current_platform.is_cuda() and current_platform.has_device_capability(75)
def _nvfp4_compute_scale_factor(

View File

@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
import torch
from vllm.scalar_type import scalar_types
@@ -11,9 +13,10 @@ __all__ = [
]
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT4_E2M1_MAX_RECIPROCAL = 1 / FLOAT4_E2M1_MAX
kE2M1ToFloat = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
kE2M1ToFloat_handle = SimpleNamespace(
val=torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32)
)
@@ -29,8 +32,9 @@ def break_fp4_bytes(a, dtype):
# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long)
kE2M1 = kE2M1ToFloat_handle.val
# Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
# Reshape to final form
return values.reshape(m, n * 2).to(dtype=dtype)
@@ -47,7 +51,12 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
def dequantize_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
tensor_fp4: torch.Tensor,
tensor_sf: torch.Tensor,
global_scale: torch.Tensor | float,
dtype: torch.dtype,
block_size: int = 16,
swizzle: bool | None = True,
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
@@ -57,8 +66,10 @@ def dequantize_to_dtype(
tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
if swizzle:
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) * global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
@@ -67,7 +78,8 @@ def dequantize_to_dtype(
def get_reciprocal(x):
if isinstance(x, torch.Tensor):
return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
# torch.where yields operation not permitted when stream is capturing.
return 1.0 / (x + (x == 0) * 1e8)
elif isinstance(x, (float, int)):
return 0.0 if x == 0 else 1.0 / x
else:
@@ -94,7 +106,7 @@ def ref_nvfp4_quant(x, global_scale, block_size):
m, n = x.shape
x = torch.reshape(x, (m, n // block_size, block_size))
vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
scale = global_scale * (vec_max * FLOAT4_E2M1_MAX_RECIPROCAL)
scale = torch.clamp(scale, max=448, min=-448)
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
@@ -111,6 +123,7 @@ def run_nvfp4_emulations(
weight: torch.Tensor,
weight_scale_swizzled: torch.Tensor,
weight_global_scale: torch.Tensor,
swizzle: bool | None = True,
):
group_size = 16
x_m, x_k = x.shape
@@ -132,8 +145,8 @@ def run_nvfp4_emulations(
weight_scale_swizzled.data,
weight_global_scale,
output_dtype,
x.device,
group_size,
swizzle=swizzle,
)
# matmul

View File

@@ -17,31 +17,99 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
kE2M1ToFloat_handle,
run_nvfp4_emulations,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
from vllm.utils.import_utils import has_fbgemm_gpu
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
# NOTE: This is ordered by preferred backend.
# Example: if both are available, FLASHINFER_CUTLASS is preferred to VLLM_CUTLASS.
class NvFp4LinearBackend(Enum):
VLLM_CUTLASS = "cutlass"
FLASHINFER_CUTLASS = "flashinfer-cutlass"
VLLM_CUTLASS = "cutlass"
MARLIN = "marlin"
FLASHINFER_TRTLLM = "flashinfer-trtllm"
FLASHINFER_CUDNN = "flashinfer-cudnn"
FBGEMM = "fbgemm"
MARLIN = "marlin"
EMULATION = "emulation"
NVFP4_LINEAR_BACKENDS = list(NvFp4LinearBackend)
def is_backend_supported(backend: NvFp4LinearBackend) -> tuple[bool, str | None]:
reason = None
supported = True
if backend == NvFp4LinearBackend.FLASHINFER_CUTLASS:
# cutlass_fp4_supported() checks that the vLLM NVFP4 kernels (both
# quantization and GEMM) were compiled for the current SM version.
# FlashInfer backends still rely on the vLLM quantization kernels,
# so we gate them on the same check.
supported = (
cutlass_fp4_supported()
and current_platform.has_device_capability(100)
and has_flashinfer()
)
if not supported:
reason = "FlashInfer is required, >=sm_100 is required"
elif backend == NvFp4LinearBackend.VLLM_CUTLASS:
supported = cutlass_fp4_supported()
if not supported:
reason = "Cutlass is required"
elif backend == NvFp4LinearBackend.MARLIN:
supported = is_fp4_marlin_supported()
if not supported:
reason = "Marlin is required"
elif backend in [
NvFp4LinearBackend.FLASHINFER_TRTLLM,
NvFp4LinearBackend.FLASHINFER_CUDNN,
]:
supported = has_flashinfer()
if not supported:
reason = "FlashInfer is required"
elif backend == NvFp4LinearBackend.FBGEMM:
supported = has_fbgemm_gpu()
if not supported:
reason = "fbgemm_gpu is required"
elif backend == NvFp4LinearBackend.EMULATION:
# e.g. AMD Instinct does not support native NVFP4.
unsupported_reasons = {}
for other_backend in NVFP4_LINEAR_BACKENDS:
if other_backend == NvFp4LinearBackend.EMULATION:
continue
other_supported, other_reason = is_backend_supported(other_backend)
if not other_supported:
unsupported_reasons[other_backend] = other_reason
if unsupported_reasons:
unsupported_reasons_str = "\n - ".join(
[f"{b.value}: {r}" for b, r in unsupported_reasons.items()]
)
logger.warning_once(
f"NVFP4 linear falling back to the slow and unoptimized "
f"backend=NvFp4LinearBackend.EMULATION as no optimized backend is "
f"available (unavailable reasons:\n - {unsupported_reasons_str}\n). "
"In case you expect one of these backend to be used, "
"please verify your environment."
)
return supported, reason
def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
"""
Select the best available NVFP4 GEMM backend based on environment
configuration and platform capabilities.
"""
backend: NvFp4LinearBackend | None = None
selected_backend: NvFp4LinearBackend | None = None
if envs.VLLM_USE_FBGEMM:
try:
@@ -51,51 +119,36 @@ def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
"Please install with: pip install fbgemm-gpu-genai"
) from exc
backend = NvFp4LinearBackend.FBGEMM
selected_backend = NvFp4LinearBackend.FBGEMM
elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
backend = NvFp4LinearBackend.EMULATION
selected_backend = NvFp4LinearBackend.EMULATION
elif envs.VLLM_NVFP4_GEMM_BACKEND is None:
# Auto-select best available backend.
# cutlass_fp4_supported() checks that the vLLM NVFP4 kernels (both
# quantization and GEMM) were compiled for the current SM version.
# FlashInfer backends still rely on the vLLM quantization kernels,
# so we gate them on the same check.
if (
cutlass_fp4_supported()
and current_platform.has_device_capability(100)
and has_flashinfer()
):
backend = NvFp4LinearBackend.FLASHINFER_CUTLASS
elif cutlass_fp4_supported():
backend = NvFp4LinearBackend.VLLM_CUTLASS
elif is_fp4_marlin_supported():
backend = NvFp4LinearBackend.MARLIN
for backend in NVFP4_LINEAR_BACKENDS:
supported, reason = is_backend_supported(backend)
if supported:
selected_backend = backend
break
else:
backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND)
selected_backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND)
# Validate that the backend is supported
if backend in (
NvFp4LinearBackend.FLASHINFER_CUTLASS,
NvFp4LinearBackend.FLASHINFER_TRTLLM,
NvFp4LinearBackend.FLASHINFER_CUDNN,
):
assert has_flashinfer(), f"FlashInfer is required for {backend}"
assert cutlass_fp4_supported(), (
f"{backend} requires vLLM NVFP4 quantization kernels compiled "
f"for the current GPU (SM {current_platform.get_device_capability()})"
)
elif backend == NvFp4LinearBackend.VLLM_CUTLASS:
assert cutlass_fp4_supported(), f"Cutlass is required for {backend}"
elif backend == NvFp4LinearBackend.MARLIN:
assert is_fp4_marlin_supported(), f"Marlin is required for {backend}"
elif backend is None:
if selected_backend is None:
raise ValueError(
f"No NVFP4 GEMM backend selected, "
f"available backends: {list(NvFp4LinearBackend)}"
f"available backends: {NVFP4_LINEAR_BACKENDS}"
)
logger.info_once(f"Using {backend} for NVFP4 GEMM")
return backend
supported, reason = is_backend_supported(selected_backend)
if not supported:
raise ValueError(
f"The selected backend={selected_backend} is not supported in current "
f"environment. Reason: {reason}. Current environment: "
f"{envs.VLLM_USE_FBGEMM=}, {envs.VLLM_USE_NVFP4_CT_EMULATIONS=}, "
f"{envs.VLLM_NVFP4_GEMM_BACKEND}."
)
logger.info_once(f"Using {selected_backend} for NVFP4 GEMM")
return selected_backend
def prepare_weights_for_nvfp4_flashinfer_trtllm(
@@ -183,6 +236,10 @@ def convert_to_nvfp4_linear_kernel_format(
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
layer.weights_padding_cols = weights_padding_cols
elif backend == NvFp4LinearBackend.EMULATION:
# We can not call `.to(device)` during cuda graph capture - do it here instead.
# (operation not permitted when stream is capturing)
kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(layer.weight.device)
def apply_nvfp4_linear(
@@ -190,6 +247,7 @@ def apply_nvfp4_linear(
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
swizzle: bool | None = None,
) -> torch.Tensor:
"""
Apply NVFP4 linear transformation using the specified backend.
@@ -220,6 +278,7 @@ def apply_nvfp4_linear(
weight=weight,
weight_scale_swizzled=weight_scale,
weight_global_scale=weight_global_scale,
swizzle=swizzle,
)
if bias is not None:
out = out + bias

View File

@@ -409,6 +409,7 @@ class RocmPlatform(Platform):
"mxfp4",
"torchao",
"bitsandbytes",
"modelopt_fp4",
]
@classmethod

View File

@@ -461,3 +461,8 @@ def has_aiter() -> bool:
def has_mori() -> bool:
"""Whether the optional `mori` package is available."""
return _has_module("mori")
def has_fbgemm_gpu() -> bool:
"""Whether the optional `fbgemm_gpu` package is available."""
return _has_module("fbgemm_gpu")