[Perf][Kernel] Fused SiLU+Mul+Quant kernel for NVFP4 cutlass_moe (#31832)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -549,7 +549,8 @@ def run_cutlass_moe_fp4(
|
||||
num_topk,
|
||||
)
|
||||
c1 = _resize_cache(workspace13, (m * topk, n * 2))
|
||||
c2 = _resize_cache(workspace2, (m * topk, n))
|
||||
# Note: c2 workspace is no longer needed since SiLU is fused with quantization.
|
||||
# c3 reuses workspace13 after c1 is consumed.
|
||||
c3 = _resize_cache(workspace13, (m * topk, k))
|
||||
ops.cutlass_fp4_moe_mm(
|
||||
c1,
|
||||
@@ -563,9 +564,9 @@ def run_cutlass_moe_fp4(
|
||||
blockscale_offsets[:-1],
|
||||
)
|
||||
del rep_a_fp4, rep_a_blockscale
|
||||
torch.ops._C.silu_and_mul(c2, c1)
|
||||
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
|
||||
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk
|
||||
# Fused SiLU+Mul+NVFP4 quantization
|
||||
int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant(
|
||||
c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk
|
||||
)
|
||||
|
||||
ops.cutlass_fp4_moe_mm(
|
||||
|
||||
Reference in New Issue
Block a user