Support CUTLASS NVFP4 (w4a4) for Blackwell Geforce GPUs (SM120) (#21309)

Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
This commit is contained in:
Roberto L. Castro
2025-08-03 09:54:22 +02:00
committed by GitHub
parent 3f36c325fa
commit 789562c28c
6 changed files with 329 additions and 13 deletions

View File

@@ -335,7 +335,7 @@ void run_fp4_blockwise_scaled_group_mm(
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
#endif
@@ -356,7 +356,7 @@ void cutlass_fp4_group_mm(
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
// Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
@@ -398,7 +398,7 @@ void cutlass_fp4_group_mm(
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
"be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
"12.8 or above.");
#endif
}