[Kernel] SM90 CUTLASS FP8 GEMM: add support for swap AB + kernel tuning (#20396)

Signed-off-by: Faqin Zhong <faqin.zhong@gmail.com>
Co-authored-by: Duncan Moss <djm.moss@gmail.com>
lyrisz
2025-07-28 16:13:58 -07:00
committed by GitHub
parent 8aa1485fcf
commit c6c9122d50
3 changed files with 276 additions and 51 deletions

@@ -96,7 +96,7 @@ def cutlass_fp8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1)
+torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1)
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
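
Note on the tolerance change in this hunk: for floating-point tensors, `torch.testing.assert_close` accepts an element when `|actual - expected| <= atol + rtol * |expected|`. The pure-Python sketch below applies that documented rule to a single scalar to show how much looser `rtol=5e-1` is than `rtol=1e-2`. The helper name `within_tolerance` and the sample values are hypothetical and chosen for illustration only; they are not taken from the test.

```python
def within_tolerance(actual: float, expected: float,
                     rtol: float, atol: float) -> bool:
    """Scalar form of the closeness rule documented for torch.testing.assert_close."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Illustrative values only: an output of 10.5 against a baseline of 10.0.
actual, expected = 10.5, 10.0

# Old setting: allowed error is about 0.15 + 0.01 * 10 = 0.25, so 0.5 is rejected.
old_ok = within_tolerance(actual, expected, rtol=1e-2, atol=1.5e-1)

# New setting: allowed error is about 0.15 + 0.5 * 10 = 5.15, so 0.5 is accepted.
new_ok = within_tolerance(actual, expected, rtol=5e-1, atol=1.5e-1)

print(old_ok, new_ok)
```

Under these assumed values the old tolerance fails and the new one passes, which suggests the relative term dominates for outputs of larger magnitude while `atol` still governs values near zero.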