[BugFix][Kernel] Fix Illegal memory access in causal_conv1d in H100 (#9838)

Signed-off-by: mzusman <mor.zusmann@gmail.com>
This commit is contained in:
Mor Zusman
2024-10-31 22:06:25 +02:00
committed by GitHub
parent 55650c83a0
commit 9fb12f7848
3 changed files with 40 additions and 7 deletions

View File

@@ -555,7 +555,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 7e-2, 7e-2
rtol, atol = 1e-1, 1e-1
if torch.version.hip:
atol *= 2
# set seed
@@ -610,8 +610,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
dt_bias=dt_bias,
dt_softplus=True)
print("Output diff max", (out - out_ref[0]).max())
print("Output diff mean", (out - out_ref[0]).mean())
print("Output diff max", (out[:batch_size] - out_ref).max())
print("Output diff mean", (out[:batch_size] - out_ref).mean())
print("Output state diff max", (state[state_indices, :] - state_ref).max())
print("Output state diff mean",
(state[state_indices, :] - state_ref).mean())