Add cutlass support for blackwell fp8 blockwise gemm (#14383)
Signed-off-by: Shu Wang <shuw@nvidia.com>
This commit is contained in:
@@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm,
|
||||
(out, a, b, scale_a, scale_b, bias))
|
||||
@@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
|
||||
return
|
||||
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
|
||||
return
|
||||
if m % 4 != 0 and current_platform.has_device_capability(100):
|
||||
return
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user