[Kernel] Add non-gated support for NVFP4 CUTLASS MoE (#37320)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -507,11 +507,12 @@ def run_cutlass_moe_fp4(
|
||||
# Gemm 1
|
||||
a: Input tensor: [m, k] (half/bfloat16)
|
||||
a1_gscale: Activation scale per expert: [e] (float32)
|
||||
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
|
||||
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
|
||||
w1 (not an argument to cutlass_moe_fp4): [e, w1_n, k]
|
||||
w1_fp4: [e, w1_n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
|
||||
where w1_n = 2*n for gated activations (gate+up), n for non-gated (up only).
|
||||
(Note: `n` is the up projection output dim, `k` is the input dim in
|
||||
full precision)
|
||||
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
|
||||
w1_blockscale: [e, w1_n, k // block_size] (float8_e4m3)
|
||||
(Block size = 16 for NVFP4)
|
||||
|
||||
# Gemm 2
|
||||
@@ -528,6 +529,11 @@ def run_cutlass_moe_fp4(
|
||||
|
||||
assumes that topk < k < n to satisfy - up/down projection expectations.
|
||||
"""
|
||||
is_gated = activation.is_gated
|
||||
# For gated activations (e.g. SiLU), w1 output is 2*n (gate + up).
|
||||
# For non-gated activations (e.g. SiLU_NO_MUL), w1 output is n (up only).
|
||||
w1_n = n * 2 if is_gated else n
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
|
||||
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
|
||||
@@ -538,7 +544,7 @@ def run_cutlass_moe_fp4(
|
||||
and w2_blockscale.ndim == 3
|
||||
), "All Weights must be of rank 3 for cutlass_moe_fp4"
|
||||
m_a, k_a = a.shape
|
||||
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
|
||||
e_w1, w1_n_actual, half_k_w1 = w1_fp4.shape
|
||||
e_w2, k_w2, half_n_w2 = w2_fp4.shape
|
||||
|
||||
assert e_w1 == e_w2 and e_w1 == e, (
|
||||
@@ -548,7 +554,7 @@ def run_cutlass_moe_fp4(
|
||||
assert k_a == half_k_w1 * 2 and k == k_w2, (
|
||||
"Hidden size mismatch between a, w1 and w2"
|
||||
)
|
||||
assert nx2_w1 == n * 2 and half_n_w2 * 2 == n, "mismatch in expected `n`"
|
||||
assert w1_n_actual == w1_n and half_n_w2 * 2 == n, "mismatch in expected `n`"
|
||||
assert m == m_a, "input shape mismatch"
|
||||
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
|
||||
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
|
||||
@@ -589,6 +595,7 @@ def run_cutlass_moe_fp4(
|
||||
n,
|
||||
k,
|
||||
blockscale_offsets,
|
||||
is_gated=is_gated,
|
||||
)
|
||||
|
||||
a = ops.shuffle_rows(a, a_map)
|
||||
@@ -599,7 +606,7 @@ def run_cutlass_moe_fp4(
|
||||
blockscale_offsets,
|
||||
num_topk,
|
||||
)
|
||||
c1 = _resize_cache(workspace13, (m * topk, n * 2))
|
||||
c1 = _resize_cache(workspace13, (m * topk, w1_n))
|
||||
c2 = _resize_cache(workspace2, (m * topk, n))
|
||||
c3 = _resize_cache(workspace13, (m * topk, k))
|
||||
ops.cutlass_fp4_moe_mm(
|
||||
@@ -681,7 +688,7 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
@@ -695,11 +702,16 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
# SILU uses a fused silu+mul+fp4_quant kernel path.
|
||||
# Other gated activations use the generic apply_moe_activation()
|
||||
# fallback + separate fp4 quantization in run_cutlass_moe_fp4().
|
||||
# Non-gated activations (_NO_MUL) are also supported for models
|
||||
# like Nemotron-Nano that don't use gated MLP.
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
MoEActivation.SWIGLUSTEP,
|
||||
MoEActivation.SILU_NO_MUL,
|
||||
MoEActivation.GELU_NO_MUL,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user