fix: add fp4_out/sf_out/l2_global_scale params to fused_swiglu kernel() signature

The __call__ method passes these 3 Optional params to self.kernel(),
but kernel() didn't accept them, causing TypeError: too many positional
arguments during cute.compile(). This was the CuTeDSL 'arg-binding bug'
blocking P0/P1.
This commit is contained in:
2026-06-02 08:11:18 +00:00
parent 55ea109cca
commit fca72427ea

View File

@@ -1285,6 +1285,10 @@ class FusedSwiGLUScaledGroupedGemmKernel:
# ── Optional: NVFP4 per-expert global scales ── # ── Optional: NVFP4 per-expert global scales ──
global_scale_a: Optional[cute.Tensor], global_scale_a: Optional[cute.Tensor],
global_scale_b: Optional[cute.Tensor], global_scale_b: Optional[cute.Tensor],
# ── Fused SwiGLU epilogue outputs (replaces out when fused_swiglu=True) ──
fp4_out: Optional[cute.Tensor] = None,
sf_out: Optional[cute.Tensor] = None,
l2_global_scale: Optional[cute.Tensor] = None,
): ):
""" """
GPU device kernel for MoE Scaled Grouped GEMM with block scaling. GPU device kernel for MoE Scaled Grouped GEMM with block scaling.