fix: add fp4_out/sf_out/l2_global_scale params to fused_swiglu kernel() signature
The __call__ method passes these three Optional parameters to self.kernel(), but kernel() did not accept them, which raised TypeError: too many positional arguments during cute.compile(). This was the CuTeDSL 'arg-binding bug' blocking P0/P1.
@@ -1285,6 +1285,10 @@ class FusedSwiGLUScaledGroupedGemmKernel:
         # ── Optional: NVFP4 per-expert global scales ──
         global_scale_a: Optional[cute.Tensor],
         global_scale_b: Optional[cute.Tensor],
+        # ── Fused SwiGLU epilogue outputs (replaces out when fused_swiglu=True) ──
+        fp4_out: Optional[cute.Tensor] = None,
+        sf_out: Optional[cute.Tensor] = None,
+        l2_global_scale: Optional[cute.Tensor] = None,
     ):
         """
         GPU device kernel for MoE Scaled Grouped GEMM with block scaling.
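The failure mode described in the commit message can be reproduced in plain Python, without CuTeDSL. The sketch below is only an illustration: kernel_before and kernel_after are hypothetical stand-ins that keep just the trailing parameters shown in the hunk, and the use of inspect.Signature.bind is an assumption chosen because it raises a TypeError with the same "too many positional arguments" wording. How cute.compile() binds arguments internally is not shown in this commit.

```python
import inspect


def kernel_before(global_scale_a, global_scale_b):
    """Hypothetical stand-in for the pre-fix kernel() signature."""


def kernel_after(global_scale_a, global_scale_b,
                 fp4_out=None, sf_out=None, l2_global_scale=None):
    """Hypothetical stand-in for the post-fix kernel() signature."""


def try_bind(fn, *args):
    """Bind args to fn's signature; return None on success, else the error text."""
    try:
        inspect.signature(fn).bind(*args)
        return None
    except TypeError as exc:
        return str(exc)


# __call__ forwards five positional values: the two global scales plus the
# three optional fused-SwiGLU outputs (placeholder values, not real tensors).
call_args = ("global_scale_a", "global_scale_b", None, None, None)

before_error = try_bind(kernel_before, *call_args)  # signature is too short
after_error = try_bind(kernel_after, *call_args)    # all five values bind

print("before fix:", before_error)
print("after fix:", after_error)
```

Giving the three new parameters a default of None, as the diff does, also keeps any existing caller that passes only the earlier arguments working.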