[Kernel] Add NVFP4 MoE CUTLASS support for SM120 (#29242)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -67,9 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
|
||||
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \
|
||||
defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
|
||||
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
@@ -284,8 +284,9 @@ void get_cutlass_moe_mm_data(
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k,
|
||||
@@ -296,7 +297,7 @@ void get_cutlass_moe_mm_data(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
||||
"CUDA device capability: ",
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes(
|
||||
@@ -304,8 +305,9 @@ void get_cutlass_moe_mm_problem_sizes(
|
||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
|
||||
problem_sizes2, num_experts, n, k,
|
||||
blockscale_offsets);
|
||||
@@ -315,7 +317,7 @@ void get_cutlass_moe_mm_problem_sizes(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
|
||||
"kernel for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
@@ -328,8 +330,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
||||
problem_sizes2, expert_num_tokens,
|
||||
num_local_experts, padded_m, n, k);
|
||||
@@ -339,7 +342,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
false,
|
||||
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
|
||||
"for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
Reference in New Issue
Block a user