[Kernel] Add non-gated support for NVFP4 CUTLASS MoE (#37320)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-03-17 23:12:04 +01:00
committed by GitHub
parent 3ed7b1e6e0
commit 09e4576f65
8 changed files with 53 additions and 26 deletions

View File

@@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::Tensor& expert_first_token_offset,

View File

@@ -17,8 +17,11 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
const int topk_length, const int n,
const int k) {
const int k, const bool is_gated) {
int expert_id = blockIdx.x;
// For gated activations (gate + up), first GEMM output is 2*n.
// For non-gated activations (up only), first GEMM output is n.
int const n1 = is_gated ? 2 * n : n;
int occurrences = 0;
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
@@ -31,13 +34,13 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int final_occurrences = atomic_buffer[expert_id];
if constexpr (!SWAP_AB) {
problem_sizes1[expert_id * 3] = final_occurrences;
problem_sizes1[expert_id * 3 + 1] = 2 * n;
problem_sizes1[expert_id * 3 + 1] = n1;
problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3] = final_occurrences;
problem_sizes2[expert_id * 3 + 1] = k;
problem_sizes2[expert_id * 3 + 2] = n;
} else {
problem_sizes1[expert_id * 3] = 2 * n;
problem_sizes1[expert_id * 3] = n1;
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3] = k;
@@ -107,13 +110,11 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
}
namespace {
inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& atomic_buffer,
int64_t num_experts, int64_t n,
int64_t k, cudaStream_t stream,
const bool swap_ab) {
inline void launch_compute_problem_sizes(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
const bool swap_ab, const bool is_gated) {
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
@@ -125,7 +126,7 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
static_cast<int>(k));
static_cast<int>(k), is_gated);
});
}
} // namespace
@@ -222,7 +223,8 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
@@ -236,7 +238,7 @@ void get_cutlass_moe_mm_data_caller(
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
atomic_buffer, num_experts, n, k, stream,
may_swap_ab);
may_swap_ab, is_gated);
if (blockscale_offsets.has_value()) {
// fp4 path

View File

@@ -75,7 +75,8 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::Tensor& expert_first_token_offset,
@@ -278,7 +279,8 @@ void get_cutlass_moe_mm_data(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
@@ -288,7 +290,7 @@ void get_cutlass_moe_mm_data(
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
blockscale_offsets);
blockscale_offsets, is_gated);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(

View File

@@ -489,8 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k, Tensor? blockscale_offsets) -> "
"()");
" int n, int k, Tensor? blockscale_offsets, "
" bool is_gated) -> ()");
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
// compute per-expert problem sizes from expert_first_token_offset

View File

@@ -0,0 +1,5 @@
model_name: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
accuracy_threshold: 0.29
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=cutlass"

View File

@@ -15,3 +15,4 @@ Mixtral-8x7B-BF16-fi-cutlass.yaml
Mixtral-8x7B-BF16-triton.yaml
Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml
Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml
Nemotron-Nano-30B-NvFp4-ModelOpt-vllm-cutlass.yaml

View File

@@ -989,6 +989,7 @@ def get_cutlass_moe_mm_data(
n: int,
k: int,
blockscale_offsets: torch.Tensor | None = None,
is_gated: bool = True,
):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
@@ -1012,6 +1013,8 @@ def get_cutlass_moe_mm_data(
its computation. The number of block scale rows
computed with expert E is blockscale_offsets[E + 1] -
blockscale_offsets[E]
- is_gated: Whether the activation is gated (gate + up). When True, the
first GEMM N dimension is 2*n; when False, it is n.
"""
return torch.ops._C.get_cutlass_moe_mm_data(
topk_ids,
@@ -1024,6 +1027,7 @@ def get_cutlass_moe_mm_data(
n,
k,
blockscale_offsets,
is_gated,
)

View File

@@ -507,11 +507,12 @@ def run_cutlass_moe_fp4(
# Gemm 1
a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32)
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
w1 (not an argument to cutlass_moe_fp4): [e, w1_n, k]
w1_fp4: [e, w1_n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
where w1_n = 2*n for gated activations (gate+up), n for non-gated (up only).
(Note: `n` is the up projection output dim, `k` is the input dim in
full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
w1_blockscale: [e, w1_n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4)
# Gemm 2
@@ -528,6 +529,11 @@ def run_cutlass_moe_fp4(
assumes that topk < k < n to satisfy - up/down projection expectations.
"""
is_gated = activation.is_gated
# For gated activations (e.g. SiLU), w1 output is 2*n (gate + up).
# For non-gated activations (e.g. SiLU_NO_MUL), w1 output is n (up only).
w1_n = n * 2 if is_gated else n
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
@@ -538,7 +544,7 @@ def run_cutlass_moe_fp4(
and w2_blockscale.ndim == 3
), "All Weights must be of rank 3 for cutlass_moe_fp4"
m_a, k_a = a.shape
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w1, w1_n_actual, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert e_w1 == e_w2 and e_w1 == e, (
@@ -548,7 +554,7 @@ def run_cutlass_moe_fp4(
assert k_a == half_k_w1 * 2 and k == k_w2, (
"Hidden size mismatch between a, w1 and w2"
)
assert nx2_w1 == n * 2 and half_n_w2 * 2 == n, "mismatch in expected `n`"
assert w1_n_actual == w1_n and half_n_w2 * 2 == n, "mismatch in expected `n`"
assert m == m_a, "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
@@ -589,6 +595,7 @@ def run_cutlass_moe_fp4(
n,
k,
blockscale_offsets,
is_gated=is_gated,
)
a = ops.shuffle_rows(a, a_map)
@@ -599,7 +606,7 @@ def run_cutlass_moe_fp4(
blockscale_offsets,
num_topk,
)
c1 = _resize_cache(workspace13, (m * topk, n * 2))
c1 = _resize_cache(workspace13, (m * topk, w1_n))
c2 = _resize_cache(workspace2, (m * topk, n))
c3 = _resize_cache(workspace13, (m * topk, k))
ops.cutlass_fp4_moe_mm(
@@ -681,7 +688,7 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
return True
@staticmethod
def _supports_quant_scheme(
@@ -695,11 +702,16 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
# SILU uses a fused silu+mul+fp4_quant kernel path.
# Other gated activations use the generic apply_moe_activation()
# fallback + separate fp4 quantization in run_cutlass_moe_fp4().
# Non-gated activations (_NO_MUL) are also supported for models
# like Nemotron-Nano that don't use gated MLP.
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]
@staticmethod