Add cutlass support for blackwell fp8 blockwise gemm (#14383)

Signed-off-by: Shu Wang <shuw@nvidia.com>
This commit is contained in:
Shu Wang
2025-05-08 17:09:55 -05:00
committed by GitHub
parent 4f605a6de5
commit 376786fac1
11 changed files with 332 additions and 64 deletions

View File

@@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1)
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
@@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
return
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
return
if m % 4 != 0 and current_platform.has_device_capability(100):
return
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)