[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)

This commit is contained in:
Mor Zusman
2024-10-17 00:12:43 +08:00
committed by GitHub
parent 415f76a9cb
commit fb60ae9b91
15 changed files with 504 additions and 432 deletions

View File

@@ -5,6 +5,7 @@ from einops import rearrange, repeat
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything
@@ -174,7 +175,8 @@ def selective_scan_opcheck_fn(u,
cu_seq_len=None,
cache_indices=None,
has_initial_state=None,
ssm_states=None):
ssm_states=None,
pad_slot_id=PAD_SLOT_ID):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
"""
@@ -203,7 +205,7 @@ def selective_scan_opcheck_fn(u,
# a bogus error.
opcheck(torch.ops._C.selective_scan_fwd,
(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len,
cache_indices, has_initial_state, ssm_states),
cache_indices, has_initial_state, ssm_states, pad_slot_id),
test_utils=["test_schema", "test_faketensor"])
@@ -404,9 +406,12 @@ def test_selective_state_update(dim, dstate, has_z, itype):
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
has_D, has_z, has_delta_bias, delta_softplus,
return_last_state, seqlen, itype, wtype):
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [False, True])
def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
varBC_groups, has_D, has_z, has_delta_bias,
delta_softplus, return_last_state, seqlen,
itype, wtype):
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
pytest.skip() # This config is not applicable
device = 'cuda'
@@ -420,18 +425,27 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
# set seed
torch.random.manual_seed(0)
seqlens = []
nsplits = 3
batch_size = 4
if seqlen < 10:
nsplits = 0
batch_size = 1
padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding
if with_padding and seqlen < padded_batch_size:
pytest.skip()
nsplits = padded_batch_size - 1
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])
total_entries = batch_size * 10
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0).cuda()
@@ -462,22 +476,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
delta_ref = delta.clone()
out = None
out_ref = None
prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1]))
prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
prev_state = torch.randn(prev_state_shape,
device=u.device,
dtype=itype,
requires_grad=False)
prev_state_ref = prev_state.clone()
cache_indices = torch.randperm(cumsum.shape[0] - 1,
state_indices = torch.randperm(total_entries,
dtype=torch.int32,
device=u.device)
device=u.device)[:batch_size]
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1)
has_initial_state = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=u.device)
out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
delta_softplus, cumsum, cache_indices,
delta_softplus, cumsum, padded_state_indices,
has_initial_state)
outs_ref = []
splits = [
@@ -486,6 +511,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
]
for i in range(len(seqlens[0])):
u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits]
if padded_state_indices[i] == PAD_SLOT_ID:
continue
out_ref_s, _ = selective_scan_ref(
u_s,
delta_s,
@@ -497,21 +524,22 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state,
prev_state=prev_state_ref[cache_indices[i]].unsqueeze(0)
prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0)
if has_initial_state[i] else None,
final_state_out=prev_state_ref[cache_indices[i]].unsqueeze(0))
final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(
0))
outs_ref.append(out_ref_s)
out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0]
out_ref = torch.cat(outs_ref, dim=-1)[0]
print("Output diff max", (out - out_ref[0]).max())
print("Output diff mean", (out - out_ref[0]).mean())
unpadded_out = out[:, :out_ref[0].shape[-1]]
print("Output diff max", (unpadded_out - out_ref).max())
print("Output diff mean", (unpadded_out - out_ref).mean())
print("Output state diff max", (prev_state - prev_state_ref).max())
print("Output state diff mean", (prev_state - prev_state_ref).mean())
assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol)
assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)
selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, cumsum, cache_indices,
delta_softplus, cumsum, padded_state_indices,
has_initial_state, prev_state)
@@ -520,7 +548,10 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
has_z, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
@@ -530,21 +561,32 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
# set seed
torch.random.manual_seed(0)
batch_size = 3
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
B = torch.randn(padded_batch_size, dstate, device=device)
C = torch.randn(padded_batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].detach().clone()
state_ref = state[state_indices, :].clone()
state_before = state.clone()
out = selective_state_update(state,
x,
dt,
@@ -555,15 +597,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices)
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = selective_state_update_ref(state_ref,
x,
dt,
x[:batch_size],
dt[:batch_size],
A,
B,
C,
B[:batch_size],
C[:batch_size],
D=D,
z=z,
z=z[:batch_size],
dt_bias=dt_bias,
dt_softplus=True)
@@ -572,11 +615,21 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
print("Output state diff max", (state[state_indices, :] - state_ref).max())
print("Output state diff mean",
(state[state_indices, :] - state_ref).mean())
# test padded entries stay the same
if with_padding:
assert torch.equal(state_before[unused_states_bool],
state[unused_states_bool])
assert torch.equal(x[batch_size + 1:], x[batch_size + 1:])
assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:])
assert torch.equal(B[batch_size + 1:], B[batch_size + 1:])
assert torch.equal(C[batch_size + 1:], C[batch_size + 1:])
# test "real" entries
assert torch.allclose(state[state_indices, :],
state_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype",
@@ -645,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices)
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = selective_state_update_ref(state_ref,
x,
dt,