[BugFix][Kernel] Fix Illegal memory access in causal_conv1d in H100 (#9838)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
This commit is contained in:
@@ -151,7 +151,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize(
|
||||
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
||||
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
|
||||
@pytest.mark.parametrize('dim', [64])
|
||||
@pytest.mark.parametrize('batch', [1])
|
||||
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
||||
@@ -420,7 +420,10 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
|
||||
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(final_states[state_indices],
|
||||
final_states_ref[state_indices],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
||||
padded_state_indices, has_initial_states,
|
||||
|
||||
Reference in New Issue
Block a user