[BugFix][Kernel] Fix Illegal memory access in causal_conv1d in H100 (#9838)

Signed-off-by: mzusman <mor.zusmann@gmail.com>
This commit is contained in:
Mor Zusman
2024-10-31 22:06:25 +02:00
committed by GitHub
parent 55650c83a0
commit 9fb12f7848
3 changed files with 40 additions and 7 deletions

View File

@@ -151,7 +151,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
@@ -420,7 +420,10 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
assert torch.allclose(final_states[state_indices],
final_states_ref[state_indices],
rtol=rtol,
atol=atol)
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
padded_state_indices, has_initial_states,