[Kernel] SM90 CUTLASS FP8 GEMM: add support for swap AB + kernel tuning (#20396)
Signed-off-by: Faqin Zhong <faqin.zhong@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com>
This commit is contained in:
@@ -96,7 +96,7 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1)
|
||||
torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm,
|
||||
(out, a, b, scale_a, scale_b, bias))
|
||||
|
||||
Reference in New Issue
Block a user