[NVIDIA] Add support for cudnn fp4 gemm via flashinfer (#26107)
Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -14,7 +14,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
run_nvfp4_emulations,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported,
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
@@ -29,10 +32,12 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
|
||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
def __init__(self):
|
||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
||||
self.backend = "flashinfer-trtllm"
|
||||
logger.info_once("Using flashinfer-trtllm for FP4")
|
||||
self.backend = "none"
|
||||
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||
if has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif envs.VLLM_USE_FBGEMM:
|
||||
self.backend = "fbgemm"
|
||||
try:
|
||||
@@ -42,12 +47,17 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
|
||||
"Please install with: pip install fbgemm-gpu-genai"
|
||||
) from exc
|
||||
logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
|
||||
elif has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
logger.info_once("Using flashinfer-cutlass for FP4")
|
||||
else:
|
||||
self.backend = "cutlass"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
|
||||
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
|
||||
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
|
||||
|
||||
if self.backend == "none":
|
||||
raise ValueError(
|
||||
"No valid NVFP4 GEMM backend found. "
|
||||
"Please check your platform capability."
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
@@ -184,10 +194,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||
elif self.backend == "flashinfer-cutlass":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
||||
if self.backend.startswith("flashinfer-"):
|
||||
backend_name = self.backend[len("flashinfer-") :]
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||
elif self.backend == "fbgemm":
|
||||
out = torch.ops.fbgemm.f4f4bf16(
|
||||
x_fp4,
|
||||
@@ -198,6 +207,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
use_mx=False,
|
||||
).to(output_dtype)
|
||||
else:
|
||||
assert self.backend == "cutlass"
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
if bias is not None:
|
||||
|
||||
@@ -926,22 +926,26 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
||||
self.backend = "flashinfer-trtllm"
|
||||
elif has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif is_fp4_marlin_supported():
|
||||
self.backend = "marlin"
|
||||
else:
|
||||
self.backend = "none"
|
||||
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||
if has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif is_fp4_marlin_supported():
|
||||
self.backend = "marlin"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
|
||||
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
|
||||
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
|
||||
|
||||
if self.backend == "none":
|
||||
raise ValueError(
|
||||
"Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and"
|
||||
" above."
|
||||
"No valid NVFP4 GEMM backend found. "
|
||||
"Please check your platform capability."
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -1109,11 +1113,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||
elif self.backend == "flashinfer-cutlass":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
||||
if self.backend.startswith("flashinfer-"):
|
||||
backend_name = self.backend[len("flashinfer-") :]
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||
else:
|
||||
assert self.backend == "cutlass"
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
if bias is not None:
|
||||
|
||||
Reference in New Issue
Block a user