Update test_fp8.py (#159)
This commit is contained in:
@@ -108,7 +108,7 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
deep_gemm.fp8_m_grouped_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
for j in range(num_groups):
|
||||
diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()])
|
||||
assert diff < 0.001, f'{m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'
|
||||
assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'
|
||||
|
||||
# Construct full cases
|
||||
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
|
||||
|
||||
Reference in New Issue
Block a user