[NVFP4] Support NVFP4 dense models from modelopt and compressed-tensors on AMD Instinct MI300, MI355X and Hopper through emulation (#35733)
Signed-off-by: Felix Marty <Felix.Marty@amd.com> Signed-off-by: fxmarty-amd <felmarty@amd.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -89,22 +89,33 @@ def test_models(example_prompts, model_name) -> None:
|
||||
|
||||
EAGER = [True, False]
|
||||
|
||||
SM_100_NVFP4_BACKENDS = [
|
||||
"flashinfer-cudnn",
|
||||
"flashinfer-trtllm",
|
||||
"flashinfer-cutlass",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(100),
|
||||
reason="modelopt_fp4 is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("model", ["nvidia/Llama-3.1-8B-Instruct-NVFP4"])
|
||||
@pytest.mark.parametrize("eager", EAGER)
|
||||
@pytest.mark.parametrize(
|
||||
"backend",
|
||||
[
|
||||
"emulation",
|
||||
"flashinfer-cudnn",
|
||||
"flashinfer-trtllm", # the small seq_len ensures trtllm_8x4_layout backend is used
|
||||
"flashinfer-cutlass",
|
||||
],
|
||||
)
|
||||
def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch):
|
||||
if (
|
||||
not current_platform.has_device_capability(100)
|
||||
and backend in SM_100_NVFP4_BACKENDS
|
||||
):
|
||||
pytest.skip(
|
||||
f"The backend {backend} is not supported with current_platform.has_device_capability(100) == False"
|
||||
)
|
||||
|
||||
monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend)
|
||||
with vllm_runner(model, enforce_eager=eager) as llm:
|
||||
output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)
|
||||
|
||||
@@ -366,9 +366,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"args",
|
||||
[
|
||||
@@ -398,7 +395,7 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
|
||||
assert qkv_proj.scheme.group_size == 16
|
||||
|
||||
llm.apply_model(check_model)
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=4)
|
||||
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
|
||||
print(output)
|
||||
assert output
|
||||
|
||||
|
||||
@@ -1464,6 +1464,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
|
||||
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
|
||||
# - "marlin": use marlin GEMM backend (for GPUs without native FP4 support)
|
||||
# - "emulation":
|
||||
# use BF16/FP16 GEMM, dequantizing weights and running QDQ on activations.
|
||||
# This is only meant for research purposes to run on devices where NVFP4
|
||||
# GEMM kernels are not available.
|
||||
# - <none>: automatically pick an available backend
|
||||
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
|
||||
"VLLM_NVFP4_GEMM_BACKEND",
|
||||
@@ -1474,6 +1478,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"flashinfer-cutlass",
|
||||
"cutlass",
|
||||
"marlin",
|
||||
"emulation",
|
||||
],
|
||||
),
|
||||
# Controls garbage collection during CUDA graph capture.
|
||||
|
||||
@@ -5,10 +5,12 @@ from collections.abc import Callable
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
NvFp4LinearBackend,
|
||||
apply_nvfp4_linear,
|
||||
convert_to_nvfp4_linear_kernel_format,
|
||||
select_nvfp4_linear_backend,
|
||||
@@ -19,6 +21,9 @@ from vllm.model_executor.parameter import (
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
__all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
|
||||
|
||||
@@ -27,6 +32,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
self.backend = select_nvfp4_linear_backend()
|
||||
self.group_size = 16
|
||||
|
||||
self.swizzle = None
|
||||
if self.backend == NvFp4LinearBackend.EMULATION:
|
||||
self.swizzle = False
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
@@ -89,6 +98,19 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
# Rename CT checkpoint names to standardized names
|
||||
layer.weight = layer.weight_packed
|
||||
del layer.weight_packed
|
||||
|
||||
if (
|
||||
torch.unique(layer.input_global_scale).numel() != 1
|
||||
or torch.unique(layer.weight_global_scale).numel() != 1
|
||||
):
|
||||
logger.warning_once(
|
||||
"In NVFP4 linear, the global scale for input or weight are different"
|
||||
" for parallel layers (e.g. q_proj, k_proj, v_proj). This "
|
||||
" will likely result in reduced accuracy. Please verify the model"
|
||||
" accuracy. Consider using a checkpoint with a shared global NVFP4"
|
||||
" scale for fused layers."
|
||||
)
|
||||
|
||||
# Process global scales (CT stores as divisors, i.e. 1/scale)
|
||||
input_global_scale_inv = layer.input_global_scale.max().to(torch.float32)
|
||||
layer.input_global_scale = Parameter(
|
||||
@@ -121,4 +143,5 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
layer=layer,
|
||||
x=x,
|
||||
bias=bias,
|
||||
swizzle=self.swizzle,
|
||||
)
|
||||
|
||||
@@ -71,6 +71,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
mxfp8_e4m3_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
NvFp4LinearBackend,
|
||||
apply_nvfp4_linear,
|
||||
convert_to_nvfp4_linear_kernel_format,
|
||||
select_nvfp4_linear_backend,
|
||||
@@ -1074,6 +1075,10 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
self.marlin_input_dtype = None
|
||||
self.backend = select_nvfp4_linear_backend()
|
||||
|
||||
self.swizzle = None
|
||||
if self.backend == NvFp4LinearBackend.EMULATION:
|
||||
self.swizzle = False
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -1149,10 +1154,23 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if (
|
||||
torch.unique(layer.input_scale).numel() != 1
|
||||
or torch.unique(layer.weight_scale_2).numel() != 1
|
||||
):
|
||||
logger.warning_once(
|
||||
"In NVFP4 linear, the global scale for input or weight are different"
|
||||
" for parallel layers (e.g. q_proj, k_proj, v_proj). This "
|
||||
" will likely results in reduce accuracy. Please verify the model"
|
||||
" accuracy. Consider using a checkpoint with a shared global NVFP4"
|
||||
" scale for parallel layers."
|
||||
)
|
||||
|
||||
# Rename ModelOpt checkpoint names to standardized names
|
||||
input_global_scale = layer.input_scale.max().to(torch.float32)
|
||||
layer.input_global_scale = Parameter(input_global_scale, requires_grad=False)
|
||||
del layer.input_scale
|
||||
|
||||
weight_global_scale = layer.weight_scale_2.max().to(torch.float32)
|
||||
layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False)
|
||||
del layer.weight_scale_2
|
||||
@@ -1179,6 +1197,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
layer=layer,
|
||||
x=x,
|
||||
bias=bias,
|
||||
swizzle=self.swizzle,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_fp4_marlin_supported():
|
||||
return current_platform.has_device_capability(75)
|
||||
return current_platform.is_cuda() and current_platform.has_device_capability(75)
|
||||
|
||||
|
||||
def _nvfp4_compute_scale_factor(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.scalar_type import scalar_types
|
||||
@@ -11,9 +13,10 @@ __all__ = [
|
||||
]
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT4_E2M1_MAX_RECIPROCAL = 1 / FLOAT4_E2M1_MAX
|
||||
|
||||
kE2M1ToFloat = torch.tensor(
|
||||
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
|
||||
kE2M1ToFloat_handle = SimpleNamespace(
|
||||
val=torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32)
|
||||
)
|
||||
|
||||
|
||||
@@ -29,8 +32,9 @@ def break_fp4_bytes(a, dtype):
|
||||
# Vectorized sign and magnitude extraction
|
||||
signs = (combined & 0x08).to(torch.bool) # Sign bits
|
||||
abs_vals = (combined & 0x07).to(torch.long)
|
||||
|
||||
kE2M1 = kE2M1ToFloat_handle.val
|
||||
# Device-aware lookup and sign application
|
||||
kE2M1 = kE2M1ToFloat.to(device=a.device)
|
||||
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
|
||||
# Reshape to final form
|
||||
return values.reshape(m, n * 2).to(dtype=dtype)
|
||||
@@ -47,7 +51,12 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
|
||||
|
||||
|
||||
def dequantize_to_dtype(
|
||||
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
|
||||
tensor_fp4: torch.Tensor,
|
||||
tensor_sf: torch.Tensor,
|
||||
global_scale: torch.Tensor | float,
|
||||
dtype: torch.dtype,
|
||||
block_size: int = 16,
|
||||
swizzle: bool | None = True,
|
||||
):
|
||||
"""Dequantize the fp4 tensor back to high precision."""
|
||||
# Two fp4 values are packed into one uint8.
|
||||
@@ -57,8 +66,10 @@ def dequantize_to_dtype(
|
||||
tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32)
|
||||
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
|
||||
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
|
||||
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
|
||||
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
|
||||
|
||||
if swizzle:
|
||||
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
|
||||
tensor_sf_dtype = tensor_sf.to(torch.float32) * global_scale
|
||||
|
||||
# scale the tensor
|
||||
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
|
||||
@@ -67,7 +78,8 @@ def dequantize_to_dtype(
|
||||
|
||||
def get_reciprocal(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
|
||||
# torch.where yields operation not permitted when stream is capturing.
|
||||
return 1.0 / (x + (x == 0) * 1e8)
|
||||
elif isinstance(x, (float, int)):
|
||||
return 0.0 if x == 0 else 1.0 / x
|
||||
else:
|
||||
@@ -94,7 +106,7 @@ def ref_nvfp4_quant(x, global_scale, block_size):
|
||||
m, n = x.shape
|
||||
x = torch.reshape(x, (m, n // block_size, block_size))
|
||||
vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
|
||||
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
|
||||
scale = global_scale * (vec_max * FLOAT4_E2M1_MAX_RECIPROCAL)
|
||||
scale = torch.clamp(scale, max=448, min=-448)
|
||||
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
|
||||
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
|
||||
@@ -111,6 +123,7 @@ def run_nvfp4_emulations(
|
||||
weight: torch.Tensor,
|
||||
weight_scale_swizzled: torch.Tensor,
|
||||
weight_global_scale: torch.Tensor,
|
||||
swizzle: bool | None = True,
|
||||
):
|
||||
group_size = 16
|
||||
x_m, x_k = x.shape
|
||||
@@ -132,8 +145,8 @@ def run_nvfp4_emulations(
|
||||
weight_scale_swizzled.data,
|
||||
weight_global_scale,
|
||||
output_dtype,
|
||||
x.device,
|
||||
group_size,
|
||||
swizzle=swizzle,
|
||||
)
|
||||
|
||||
# matmul
|
||||
|
||||
@@ -17,31 +17,99 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_fp4_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
|
||||
kE2M1ToFloat_handle,
|
||||
run_nvfp4_emulations,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
|
||||
from vllm.utils.import_utils import has_fbgemm_gpu
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# NOTE: This is ordered by preferred backend.
|
||||
# Example: if both are available, FLASHINFER_CUTLASS is preferred to VLLM_CUTLASS.
|
||||
class NvFp4LinearBackend(Enum):
|
||||
VLLM_CUTLASS = "cutlass"
|
||||
FLASHINFER_CUTLASS = "flashinfer-cutlass"
|
||||
VLLM_CUTLASS = "cutlass"
|
||||
MARLIN = "marlin"
|
||||
FLASHINFER_TRTLLM = "flashinfer-trtllm"
|
||||
FLASHINFER_CUDNN = "flashinfer-cudnn"
|
||||
FBGEMM = "fbgemm"
|
||||
MARLIN = "marlin"
|
||||
EMULATION = "emulation"
|
||||
|
||||
|
||||
NVFP4_LINEAR_BACKENDS = list(NvFp4LinearBackend)
|
||||
|
||||
|
||||
def is_backend_supported(backend: NvFp4LinearBackend) -> tuple[bool, str | None]:
|
||||
reason = None
|
||||
supported = True
|
||||
|
||||
if backend == NvFp4LinearBackend.FLASHINFER_CUTLASS:
|
||||
# cutlass_fp4_supported() checks that the vLLM NVFP4 kernels (both
|
||||
# quantization and GEMM) were compiled for the current SM version.
|
||||
# FlashInfer backends still rely on the vLLM quantization kernels,
|
||||
# so we gate them on the same check.
|
||||
supported = (
|
||||
cutlass_fp4_supported()
|
||||
and current_platform.has_device_capability(100)
|
||||
and has_flashinfer()
|
||||
)
|
||||
|
||||
if not supported:
|
||||
reason = "FlashInfer is required, >=sm_100 is required"
|
||||
elif backend == NvFp4LinearBackend.VLLM_CUTLASS:
|
||||
supported = cutlass_fp4_supported()
|
||||
if not supported:
|
||||
reason = "Cutlass is required"
|
||||
elif backend == NvFp4LinearBackend.MARLIN:
|
||||
supported = is_fp4_marlin_supported()
|
||||
if not supported:
|
||||
reason = "Marlin is required"
|
||||
elif backend in [
|
||||
NvFp4LinearBackend.FLASHINFER_TRTLLM,
|
||||
NvFp4LinearBackend.FLASHINFER_CUDNN,
|
||||
]:
|
||||
supported = has_flashinfer()
|
||||
if not supported:
|
||||
reason = "FlashInfer is required"
|
||||
elif backend == NvFp4LinearBackend.FBGEMM:
|
||||
supported = has_fbgemm_gpu()
|
||||
if not supported:
|
||||
reason = "fbgemm_gpu is required"
|
||||
elif backend == NvFp4LinearBackend.EMULATION:
|
||||
# e.g. AMD Instinct does not support native NVFP4.
|
||||
unsupported_reasons = {}
|
||||
for other_backend in NVFP4_LINEAR_BACKENDS:
|
||||
if other_backend == NvFp4LinearBackend.EMULATION:
|
||||
continue
|
||||
other_supported, other_reason = is_backend_supported(other_backend)
|
||||
if not other_supported:
|
||||
unsupported_reasons[other_backend] = other_reason
|
||||
|
||||
if unsupported_reasons:
|
||||
unsupported_reasons_str = "\n - ".join(
|
||||
[f"{b.value}: {r}" for b, r in unsupported_reasons.items()]
|
||||
)
|
||||
logger.warning_once(
|
||||
f"NVFP4 linear falling back to the slow and unoptimized "
|
||||
f"backend=NvFp4LinearBackend.EMULATION as no optimized backend is "
|
||||
f"available (unavailable reasons:\n - {unsupported_reasons_str}\n). "
|
||||
"In case you expect one of these backend to be used, "
|
||||
"please verify your environment."
|
||||
)
|
||||
|
||||
return supported, reason
|
||||
|
||||
|
||||
def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
|
||||
"""
|
||||
Select the best available NVFP4 GEMM backend based on environment
|
||||
configuration and platform capabilities.
|
||||
"""
|
||||
backend: NvFp4LinearBackend | None = None
|
||||
selected_backend: NvFp4LinearBackend | None = None
|
||||
|
||||
if envs.VLLM_USE_FBGEMM:
|
||||
try:
|
||||
@@ -51,51 +119,36 @@ def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
|
||||
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
|
||||
"Please install with: pip install fbgemm-gpu-genai"
|
||||
) from exc
|
||||
backend = NvFp4LinearBackend.FBGEMM
|
||||
selected_backend = NvFp4LinearBackend.FBGEMM
|
||||
elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||
backend = NvFp4LinearBackend.EMULATION
|
||||
selected_backend = NvFp4LinearBackend.EMULATION
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||
# Auto-select best available backend.
|
||||
# cutlass_fp4_supported() checks that the vLLM NVFP4 kernels (both
|
||||
# quantization and GEMM) were compiled for the current SM version.
|
||||
# FlashInfer backends still rely on the vLLM quantization kernels,
|
||||
# so we gate them on the same check.
|
||||
if (
|
||||
cutlass_fp4_supported()
|
||||
and current_platform.has_device_capability(100)
|
||||
and has_flashinfer()
|
||||
):
|
||||
backend = NvFp4LinearBackend.FLASHINFER_CUTLASS
|
||||
elif cutlass_fp4_supported():
|
||||
backend = NvFp4LinearBackend.VLLM_CUTLASS
|
||||
elif is_fp4_marlin_supported():
|
||||
backend = NvFp4LinearBackend.MARLIN
|
||||
for backend in NVFP4_LINEAR_BACKENDS:
|
||||
supported, reason = is_backend_supported(backend)
|
||||
if supported:
|
||||
selected_backend = backend
|
||||
break
|
||||
else:
|
||||
backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND)
|
||||
selected_backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND)
|
||||
|
||||
# Validate that the backend is supported
|
||||
if backend in (
|
||||
NvFp4LinearBackend.FLASHINFER_CUTLASS,
|
||||
NvFp4LinearBackend.FLASHINFER_TRTLLM,
|
||||
NvFp4LinearBackend.FLASHINFER_CUDNN,
|
||||
):
|
||||
assert has_flashinfer(), f"FlashInfer is required for {backend}"
|
||||
assert cutlass_fp4_supported(), (
|
||||
f"{backend} requires vLLM NVFP4 quantization kernels compiled "
|
||||
f"for the current GPU (SM {current_platform.get_device_capability()})"
|
||||
)
|
||||
elif backend == NvFp4LinearBackend.VLLM_CUTLASS:
|
||||
assert cutlass_fp4_supported(), f"Cutlass is required for {backend}"
|
||||
elif backend == NvFp4LinearBackend.MARLIN:
|
||||
assert is_fp4_marlin_supported(), f"Marlin is required for {backend}"
|
||||
elif backend is None:
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
f"No NVFP4 GEMM backend selected, "
|
||||
f"available backends: {list(NvFp4LinearBackend)}"
|
||||
f"available backends: {NVFP4_LINEAR_BACKENDS}"
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {backend} for NVFP4 GEMM")
|
||||
return backend
|
||||
supported, reason = is_backend_supported(selected_backend)
|
||||
|
||||
if not supported:
|
||||
raise ValueError(
|
||||
f"The selected backend={selected_backend} is not supported in current "
|
||||
f"environment. Reason: {reason}. Current environment: "
|
||||
f"{envs.VLLM_USE_FBGEMM=}, {envs.VLLM_USE_NVFP4_CT_EMULATIONS=}, "
|
||||
f"{envs.VLLM_NVFP4_GEMM_BACKEND}."
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {selected_backend} for NVFP4 GEMM")
|
||||
return selected_backend
|
||||
|
||||
|
||||
def prepare_weights_for_nvfp4_flashinfer_trtllm(
|
||||
@@ -183,6 +236,10 @@ def convert_to_nvfp4_linear_kernel_format(
|
||||
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
||||
layer.weights_padding_cols = weights_padding_cols
|
||||
elif backend == NvFp4LinearBackend.EMULATION:
|
||||
# We can not call `.to(device)` during cuda graph capture - do it here instead.
|
||||
# (operation not permitted when stream is capturing)
|
||||
kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(layer.weight.device)
|
||||
|
||||
|
||||
def apply_nvfp4_linear(
|
||||
@@ -190,6 +247,7 @@ def apply_nvfp4_linear(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
swizzle: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply NVFP4 linear transformation using the specified backend.
|
||||
@@ -220,6 +278,7 @@ def apply_nvfp4_linear(
|
||||
weight=weight,
|
||||
weight_scale_swizzled=weight_scale,
|
||||
weight_global_scale=weight_global_scale,
|
||||
swizzle=swizzle,
|
||||
)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
|
||||
@@ -409,6 +409,7 @@ class RocmPlatform(Platform):
|
||||
"mxfp4",
|
||||
"torchao",
|
||||
"bitsandbytes",
|
||||
"modelopt_fp4",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -461,3 +461,8 @@ def has_aiter() -> bool:
|
||||
def has_mori() -> bool:
|
||||
"""Whether the optional `mori` package is available."""
|
||||
return _has_module("mori")
|
||||
|
||||
|
||||
def has_fbgemm_gpu() -> bool:
|
||||
"""Whether the optional `fbgemm_gpu` package is available."""
|
||||
return _has_module("fbgemm_gpu")
|
||||
|
||||
Reference in New Issue
Block a user