Fix: Add out= parameter to run_fused_swiglu_grouped_gemm signature
This commit is contained in:
@@ -416,6 +416,7 @@ def run_fused_swiglu_grouped_gemm(
|
||||
swiglu_limit=0.0,
|
||||
mma_tiler_mn=(128, 128),
|
||||
cluster_shape_mn=(1, 1),
|
||||
out=None, # pre-allocated output buffer for CUDA graph capture
|
||||
):
|
||||
"""Run the fused SwiGLU NVFP4 scaled grouped GEMM.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user