[Kernel] Change interface to Mamba causal_conv1d_update for continuous batching (#8012)
This commit is contained in:
committed by
GitHub
parent
09deb4721f
commit
8110e44529
@@ -203,3 +203,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
|
||||
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
||||
@pytest.mark.parametrize("width", [2, 3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
|
||||
silu_activation, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch = 64
|
||||
|
||||
x = torch.randn(batch, dim, device=device, dtype=itype)
|
||||
|
||||
total_entries = 10 * batch
|
||||
conv_state = torch.randn(total_entries,
|
||||
dim,
|
||||
width,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch].to(
|
||||
dtype=torch.int32, device=device)
|
||||
|
||||
weight = torch.randn(dim,
|
||||
width,
|
||||
device=device,
|
||||
dtype=itype,
|
||||
requires_grad=True)
|
||||
if has_bias:
|
||||
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
|
||||
else:
|
||||
bias = None
|
||||
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_update(x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=conv_state_indices)
|
||||
out_ref = causal_conv1d_update_ref(x,
|
||||
conv_state_ref,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
Reference in New Issue
Block a user