Support CUTLASS NVFP4 (w4a4) for Blackwell Geforce GPUs (SM120) (#21309)
Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
This commit is contained in:
committed by
GitHub
parent
3f36c325fa
commit
789562c28c
@@ -335,7 +335,7 @@ void run_fp4_blockwise_scaled_group_mm(
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
#endif
|
||||
@@ -356,7 +356,7 @@ void cutlass_fp4_group_mm(
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
||||
// Input validation
|
||||
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
|
||||
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
|
||||
@@ -398,7 +398,7 @@ void cutlass_fp4_group_mm(
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
|
||||
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
|
||||
"be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
|
||||
"12.8 or above.");
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user