[Kernel] Add non-gated support for NVFP4 CUTLASS MoE (#37320)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-03-17 23:12:04 +01:00
committed by GitHub
parent 3ed7b1e6e0
commit 09e4576f65
8 changed files with 53 additions and 26 deletions

View File

@@ -507,11 +507,12 @@ def run_cutlass_moe_fp4(
# Gemm 1
a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32)
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
w1 (not an argument to cutlass_moe_fp4): [e, w1_n, k]
w1_fp4: [e, w1_n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
where w1_n = 2*n for gated activations (gate+up), n for non-gated (up only).
(Note: `n` is the up projection output dim, `k` is the input dim in
full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
w1_blockscale: [e, w1_n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4)
# Gemm 2
@@ -528,6 +529,11 @@ def run_cutlass_moe_fp4(
assumes that topk < k < n to satisfy up/down projection expectations.
"""
is_gated = activation.is_gated
# For gated activations (e.g. SiLU), w1 output is 2*n (gate + up).
# For non-gated activations (e.g. SiLU_NO_MUL), w1 output is n (up only).
w1_n = n * 2 if is_gated else n
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
@@ -538,7 +544,7 @@ def run_cutlass_moe_fp4(
and w2_blockscale.ndim == 3
), "All Weights must be of rank 3 for cutlass_moe_fp4"
m_a, k_a = a.shape
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w1, w1_n_actual, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert e_w1 == e_w2 and e_w1 == e, (
@@ -548,7 +554,7 @@ def run_cutlass_moe_fp4(
assert k_a == half_k_w1 * 2 and k == k_w2, (
"Hidden size mismatch between a, w1 and w2"
)
assert nx2_w1 == n * 2 and half_n_w2 * 2 == n, "mismatch in expected `n`"
assert w1_n_actual == w1_n and half_n_w2 * 2 == n, "mismatch in expected `n`"
assert m == m_a, "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
@@ -589,6 +595,7 @@ def run_cutlass_moe_fp4(
n,
k,
blockscale_offsets,
is_gated=is_gated,
)
a = ops.shuffle_rows(a, a_map)
@@ -599,7 +606,7 @@ def run_cutlass_moe_fp4(
blockscale_offsets,
num_topk,
)
c1 = _resize_cache(workspace13, (m * topk, n * 2))
c1 = _resize_cache(workspace13, (m * topk, w1_n))
c2 = _resize_cache(workspace2, (m * topk, n))
c3 = _resize_cache(workspace13, (m * topk, k))
ops.cutlass_fp4_moe_mm(
@@ -681,7 +688,7 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
return True
@staticmethod
def _supports_quant_scheme(
@@ -695,11 +702,16 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
# SILU uses a fused silu+mul+fp4_quant kernel path.
# Other gated activations use the generic apply_moe_activation()
# fallback + separate fp4 quantization in run_cutlass_moe_fp4().
# Non-gated activations (the *_NO_MUL variants) are also supported, for
# models such as Nemotron-Nano that do not use a gated MLP.
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]
@staticmethod