[Kernel] Add non-gated support for NVFP4 CUTLASS MoE (#37320)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-03-17 23:12:04 +01:00
committed by GitHub
parent 3ed7b1e6e0
commit 09e4576f65
8 changed files with 53 additions and 26 deletions

View File

@@ -75,7 +75,8 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::Tensor& expert_first_token_offset,
@@ -278,7 +279,8 @@ void get_cutlass_moe_mm_data(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
@@ -288,7 +290,7 @@ void get_cutlass_moe_mm_data(
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
blockscale_offsets);
blockscale_offsets, is_gated);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(