fix: add fp4_out/sf_out/l2_global_scale params to fused_swiglu kernel() signature
The __call__ method passes these 3 Optional params to self.kernel(), but kernel() didn't accept them, causing TypeError: too many positional arguments during cute.compile(). This was the CuTeDSL 'arg-binding bug' blocking P0/P1.
This commit is contained in:
@@ -1285,6 +1285,10 @@ class FusedSwiGLUScaledGroupedGemmKernel:
|
||||
# ── Optional: NVFP4 per-expert global scales ──
|
||||
global_scale_a: Optional[cute.Tensor],
|
||||
global_scale_b: Optional[cute.Tensor],
|
||||
# ── Fused SwiGLU epilogue outputs (replaces out when fused_swiglu=True) ──
|
||||
fp4_out: Optional[cute.Tensor] = None,
|
||||
sf_out: Optional[cute.Tensor] = None,
|
||||
l2_global_scale: Optional[cute.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
GPU device kernel for MoE Scaled Grouped GEMM with block scaling.
|
||||
|
||||
Reference in New Issue
Block a user