[NVIDIA] Bugfix NVFP4 DGX Spark and RTX50 (#38423)

Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
Signed-off-by: Johnny <johnnynuca14@gmail.com>
This commit is contained in:
Johnny
2026-03-30 18:36:18 +02:00
committed by GitHub
parent 8e6293e838
commit b4a2f3ac36
15 changed files with 86 additions and 20 deletions

View File

@@ -16,6 +16,7 @@
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
@@ -53,12 +54,27 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor const& output_scale_offset_by_experts);
#endif
static bool nvfp4_quant_sm_supported() {
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) return true;
#endif
return false;
}
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
@@ -100,6 +116,10 @@ void scaled_fp4_experts_quant(
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 experts quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
@@ -112,6 +132,10 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
torch::Tensor& input, torch::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
@@ -125,6 +149,11 @@ void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 experts quantization kernel "
"for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);

View File

@@ -63,5 +63,17 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
int runtimeVersion;
cudaRuntimeGetVersion(&runtimeVersion);
return cuda_device_capability >= 100 && runtimeVersion >= 12080;
if (runtimeVersion < 12080) return false;
// Only report support when the SM-specific kernel was actually compiled in,
// so the Python-side backend selector does not choose CUTLASS and then hit
// TORCH_CHECK_NOT_IMPLEMENTED (or worse, fall through to Marlin).
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (cuda_device_capability >= 100 && cuda_device_capability < 120)
return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (cuda_device_capability >= 120 && cuda_device_capability < 130)
return true;
#endif
return false;
}

View File

@@ -154,6 +154,7 @@ struct MacheteCollectiveMma {
struct DispatchPolicy {
constexpr static int Stages = PipelineStages;
using ClusterShape = ClusterShape_MNK;
using ArchTag = arch::Sm90;
using Schedule = KernelScheduleType;
};