Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
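Every hunk below is mechanical reformatting; no runtime behavior changes. The old yapf style aligned continuation arguments under the opening parenthesis and used single-quoted strings; ruff's black-compatible formatter switches to double quotes and 4-space hanging indents, honors the "magic" trailing comma that keeps multi-line argument lists at one argument per line, and its isort-compatible lint rules reflow the multi-name import at the top of the file. A minimal sketch of the pattern (hypothetical names, not code from this diff):

    # Before (yapf): continuations aligned under the open paren, single quotes.
    result = compute(alpha,
                     beta,
                     gamma='x')

    # After (ruff format): hanging indent, double quotes, and a trailing
    # comma that forces one argument per line.
    result = compute(
        alpha,
        beta,
        gamma="x",
    )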
@@ -10,20 +10,15 @@ from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops  # noqa: F401
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
-    selective_scan_fn, selective_state_update)
+    selective_scan_fn,
+    selective_state_update,
+)
 from vllm.platforms import current_platform


-def selective_state_update_ref(state,
-                               x,
-                               dt,
-                               A,
-                               B,
-                               C,
-                               D=None,
-                               z=None,
-                               dt_bias=None,
-                               dt_softplus=False):
+def selective_state_update_ref(
+    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
+):
     """
     Argument:
         state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -73,16 +68,17 @@ def selective_state_update_ref(state,
         assert dt_bias.shape == (nheads, dim)
         dt = dt + dt_bias
     dt = F.softplus(dt) if dt_softplus else dt
-    dA = torch.exp(rearrange(dt, "b h d -> b h d 1") *
-                   A)  # (batch, nheads, dim, dstate)
-    B = repeat(B, "b g n -> b (g h) n",
-               h=nheads // ngroups)  # (batch, nheads, dstate)
-    C = repeat(C, "b g n -> b (g h) n",
-               h=nheads // ngroups)  # (batch, nheads, dstate)
+    dA = torch.exp(
+        rearrange(dt, "b h d -> b h d 1") * A
+    )  # (batch, nheads, dim, dstate)
+    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
+    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
     dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
-        B, "b h n -> b h 1 n")  # (batch, nheads, dim, dstate)
-    state.copy_(state * dA +
-                dB * rearrange(x, "b h d -> b h d 1"))  # (batch, dim, dstate
+        B, "b h n -> b h 1 n"
+    )  # (batch, nheads, dim, dstate)
+    state.copy_(
+        state * dA + dB * rearrange(x, "b h d -> b h d 1")
+    )  # (batch, dim, dstate
     out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
     if D is not None:
         out += (x * D).to(out.dtype)
@@ -92,18 +88,20 @@ def selective_state_update_ref(state,
     return out


-def selective_scan_ref(u,
-                       delta,
-                       A,
-                       B,
-                       C,
-                       D=None,
-                       z=None,
-                       delta_bias=None,
-                       delta_softplus=False,
-                       return_last_state=False,
-                       prev_state=None,
-                       final_state_out=None):
+def selective_scan_ref(
+    u,
+    delta,
+    A,
+    B,
+    C,
+    D=None,
+    z=None,
+    delta_bias=None,
+    delta_softplus=False,
+    return_last_state=False,
+    prev_state=None,
+    final_state_out=None,
+):
     """
     u: r(B D L)
     delta: r(B D L)
@@ -132,26 +130,26 @@ def selective_scan_ref(u,
         C = C.float()
     x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
     ys = []
-    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
+    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
     if not is_variable_B:
-        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
+        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
     else:
         if B.dim() == 3:
-            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
+            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
         else:
             B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
-            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
+            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
     if is_variable_C and C.dim() == 4:
         C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
     for i in range(u.shape[2]):
         x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
         if not is_variable_C:
-            y = torch.einsum('bdn,dn->bd', x, C)
+            y = torch.einsum("bdn,dn->bd", x, C)
         else:
             if C.dim() == 3:
-                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
+                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
             else:
-                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
+                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
         if i == u.shape[2] - 1:
             if final_state_out is None:
                 final_state_out = x
@@ -166,20 +164,22 @@ def selective_scan_ref(u,
     return out if not return_last_state else (out, final_state_out)


-def selective_scan_opcheck_fn(u,
-                              delta,
-                              A,
-                              B,
-                              C,
-                              D=None,
-                              z=None,
-                              delta_bias=None,
-                              delta_softplus=False,
-                              cu_seq_len=None,
-                              cache_indices=None,
-                              has_initial_state=None,
-                              ssm_states=None,
-                              pad_slot_id=PAD_SLOT_ID):
+def selective_scan_opcheck_fn(
+    u,
+    delta,
+    A,
+    B,
+    C,
+    D=None,
+    z=None,
+    delta_bias=None,
+    delta_softplus=False,
+    cu_seq_len=None,
+    cache_indices=None,
+    has_initial_state=None,
+    ssm_states=None,
+    pad_slot_id=PAD_SLOT_ID,
+):
     """if return_last_state is True, returns (out, last_state)
     last_state has shape (batch, dim, dstate).
     """
@@ -206,30 +206,55 @@ def selective_scan_opcheck_fn(u,

     # Disable test_autograd_registration for now as it seems to trigger
     # a bogus error.
-    opcheck(torch.ops._C.selective_scan_fwd,
-            (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len,
-             cache_indices, has_initial_state, ssm_states, pad_slot_id),
-            test_utils=["test_schema", "test_faketensor"])
+    opcheck(
+        torch.ops._C.selective_scan_fwd,
+        (
+            u,
+            delta,
+            A,
+            B,
+            C,
+            D,
+            z,
+            delta_bias,
+            delta_softplus,
+            cu_seq_len,
+            cache_indices,
+            has_initial_state,
+            ssm_states,
+            pad_slot_id,
+        ),
+        test_utils=["test_schema", "test_faketensor"],
+    )


-@pytest.mark.parametrize('wtype', [torch.float32])
-@pytest.mark.parametrize('itype',
-                         [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
-@pytest.mark.parametrize('has_delta_bias', [True])
-@pytest.mark.parametrize('delta_softplus', [True])
-@pytest.mark.parametrize('has_z', [True])
-@pytest.mark.parametrize('has_D', [True])
+@pytest.mark.parametrize("wtype", [torch.float32])
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096])
+@pytest.mark.parametrize("has_delta_bias", [True])
+@pytest.mark.parametrize("delta_softplus", [True])
+@pytest.mark.parametrize("has_z", [True])
+@pytest.mark.parametrize("has_D", [True])
 @pytest.mark.parametrize("varBC_groups", [1, 2])
 @pytest.mark.parametrize("is_variable_C", [True])
 @pytest.mark.parametrize("is_variable_B", [True])
 @pytest.mark.parametrize("scan_chunks", [1, 2, 3])
-def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
-                        has_z, has_delta_bias, delta_softplus, seqlen, itype,
-                        wtype, scan_chunks):
+def test_selective_scan(
+    is_variable_B,
+    is_variable_C,
+    varBC_groups,
+    has_D,
+    has_z,
+    has_delta_bias,
+    delta_softplus,
+    seqlen,
+    itype,
+    wtype,
+    scan_chunks,
+):
     if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
         pytest.skip()  # This config is not applicable
-    device = 'cuda'
+    device = "cuda"
     rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
     if itype == torch.bfloat16:
         rtol, atol = 3e-2, 5e-2
@@ -242,7 +267,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
     batch_size = 1
     dim = 4
     dstate = 8
-    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
+    A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)
     A_ref = A.clone()
     if not is_variable_B:
         B_shape = [dim, dstate]
@@ -250,9 +275,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
         B_shape = [batch_size, dstate, seqlen]
     else:
         B_shape = [batch_size, varBC_groups, dstate, seqlen]
-    B = torch.randn(B_shape,
-                    device=device,
-                    dtype=wtype if not is_variable_B else itype)
+    B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype)
     B_ref = B.clone()
     if not is_variable_C:
         C_shape = [dim, dstate]
@@ -260,27 +283,27 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
         C_shape = [batch_size, dstate, seqlen]
     else:
         C_shape = [batch_size, varBC_groups, dstate, seqlen]
-    C = torch.randn(C_shape,
-                    device=device,
-                    dtype=wtype if not is_variable_C else itype)
+    C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
     C_ref = C.clone()
     D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
     D_ref = D.clone()
-    z = torch.randn(batch_size, dim, seqlen, device=device,
-                    dtype=itype) if has_z else None
+    z = (
+        torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
+        if has_z
+        else None
+    )
     z_ref = z.clone() if has_z else None
-    delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
-                  ) if has_delta_bias else None
+    delta_bias = (
+        (0.5 * torch.rand(dim, device=device, dtype=torch.float32))
+        if has_delta_bias
+        else None
+    )
     u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
     u_ref = u.clone()
-    delta = (0.5 *
-             torch.rand(batch_size, dim, seqlen, device=device, dtype=itype))
+    delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)
     delta_ref = delta.clone()
     state_shape = (batch_size, u.shape[1], int(A.shape[1]))
-    state = torch.randn(state_shape,
-                        device=u.device,
-                        dtype=itype,
-                        requires_grad=False)
+    state = torch.randn(state_shape, device=u.device, dtype=itype, requires_grad=False)
     state_ref = state.clone()
     out = None
     out_ref = None
@@ -312,9 +335,10 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
             z=_z,
             delta_bias=delta_bias,
             delta_softplus=delta_softplus,
-            has_initial_state=torch.ones(batch_size,
-                                         device=u.device,
-                                         dtype=torch.bool) if c > 0 else None)
+            has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
+            if c > 0
+            else None,
+        )
         outs.append(out)
     if len(outs) > 1:
         out = torch.cat(outs, dim=-1)
@@ -329,27 +353,29 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
         z=z_ref,
         delta_bias=delta_bias,
         delta_softplus=delta_softplus,
-        return_last_state=True)
+        return_last_state=True,
+    )

     assert out is not None and out_ref is not None
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
     assert state is not None and state_ref is not None
     assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol)

-    selective_scan_opcheck_fn(u,
-                              delta,
-                              A,
-                              B,
-                              C,
-                              D,
-                              z,
-                              delta_bias=delta_bias,
-                              delta_softplus=delta_softplus,
-                              ssm_states=state)
+    selective_scan_opcheck_fn(
+        u,
+        delta,
+        A,
+        B,
+        C,
+        D,
+        z,
+        delta_bias=delta_bias,
+        delta_softplus=delta_softplus,
+        ssm_states=state,
+    )


-@pytest.mark.parametrize("itype",
-                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("has_z", [False, True])
 @pytest.mark.parametrize("dstate", [16, 32, 64])
 @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@@ -374,52 +400,47 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     D = torch.randn(dim, device=device)
     z = torch.randn_like(x) if has_z else None
     state_ref = state.detach().clone()
-    selective_state_update(state,
-                           x,
-                           dt,
-                           A,
-                           B,
-                           C,
-                           D=D,
-                           z=z,
-                           dt_bias=dt_bias,
-                           dt_softplus=True,
-                           out=out)
-    out_ref = selective_state_update_ref(state_ref,
-                                         x,
-                                         dt,
-                                         A,
-                                         B,
-                                         C,
-                                         D=D,
-                                         z=z,
-                                         dt_bias=dt_bias,
-                                         dt_softplus=True)
+    selective_state_update(
+        state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, out=out
+    )
+    out_ref = selective_state_update_ref(
+        state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
+    )

     assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


-@pytest.mark.parametrize('wtype', [torch.float32])
-@pytest.mark.parametrize('itype', [torch.float32])
-@pytest.mark.parametrize('seqlen', [1, 128, 129, 256, 512, 1024, 2048, 4096])
+@pytest.mark.parametrize("wtype", [torch.float32])
+@pytest.mark.parametrize("itype", [torch.float32])
+@pytest.mark.parametrize("seqlen", [1, 128, 129, 256, 512, 1024, 2048, 4096])
 @pytest.mark.parametrize("return_last_state", [True])
-@pytest.mark.parametrize('has_delta_bias', [True])
-@pytest.mark.parametrize('delta_softplus', [True])
-@pytest.mark.parametrize('has_z', [True])
-@pytest.mark.parametrize('has_D', [True])
+@pytest.mark.parametrize("has_delta_bias", [True])
+@pytest.mark.parametrize("delta_softplus", [True])
+@pytest.mark.parametrize("has_z", [True])
+@pytest.mark.parametrize("has_D", [True])
 @pytest.mark.parametrize("varBC_groups", [1, 2])
 @pytest.mark.parametrize("is_variable_C", [True])
 @pytest.mark.parametrize("is_variable_B", [True])
 # tests correctness in case subset of the sequences are padded
 @pytest.mark.parametrize("with_padding", [False, True])
-def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
-                               varBC_groups, has_D, has_z, has_delta_bias,
-                               delta_softplus, return_last_state, seqlen,
-                               itype, wtype):
+def test_selective_scan_varlen(
+    with_padding,
+    is_variable_B,
+    is_variable_C,
+    varBC_groups,
+    has_D,
+    has_z,
+    has_delta_bias,
+    delta_softplus,
+    return_last_state,
+    seqlen,
+    itype,
+    wtype,
+):
     if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
         pytest.skip()  # This config is not applicable
-    device = 'cuda'
+    device = "cuda"
     rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
     if itype == torch.bfloat16:
         rtol, atol = 3e-2, 5e-2
@@ -443,72 +464,79 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
     eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
     seqlens.append(
         torch.diff(
-            torch.cat(
-                [torch.tensor([-1]), eos_pos,
-                 torch.tensor([seqlen - 1])])).tolist())
+            torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
+        ).tolist()
+    )

     assert sum(seqlens[-1]) == seqlen
     assert all(s > 0 for s in seqlens[-1])

     total_entries = batch_size * 10
     cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
-    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
-                          dim=0).cuda()
+    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).cuda()

     dim = 4
     dstate = 8
-    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
+    A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)
     A_ref = A.clone()
     B_shape = [varBC_groups, dstate, seqlen]
-    B = torch.randn(B_shape,
-                    device=device,
-                    dtype=wtype if not is_variable_B else itype)
+    B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype)
     B_ref = B.clone()
     C_shape = [varBC_groups, dstate, seqlen]
-    C = torch.randn(C_shape,
-                    device=device,
-                    dtype=wtype if not is_variable_C else itype)
+    C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
     C_ref = C.clone()
     D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
     D_ref = D.clone()
     z = torch.randn(dim, seqlen, device=device, dtype=itype)
     z_ref = z.clone()
-    delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
-                  ) if has_delta_bias else None
+    delta_bias = (
+        (0.5 * torch.rand(dim, device=device, dtype=torch.float32))
+        if has_delta_bias
+        else None
+    )
     u = torch.randn(dim, seqlen, device=device, dtype=itype)
     u_ref = u.clone()
-    delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype))
+    delta = 0.5 * torch.rand(dim, seqlen, device=device, dtype=itype)
     delta_ref = delta.clone()
     out = None
     out_ref = None

     prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
-    prev_state = torch.randn(prev_state_shape,
-                             device=u.device,
-                             dtype=itype,
-                             requires_grad=False)
+    prev_state = torch.randn(
+        prev_state_shape, device=u.device, dtype=itype, requires_grad=False
+    )
     prev_state_ref = prev_state.clone()
-    state_indices = torch.randperm(total_entries,
-                                   dtype=torch.int32,
-                                   device=u.device)[:batch_size]
-    unused_states_bool = torch.ones(total_entries,
-                                    dtype=torch.bool,
-                                    device=device)
+    state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[
+        :batch_size
+    ]
+    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
     unused_states_bool[state_indices] = False
-    padded_state_indices = torch.concat([
-        state_indices,
-        torch.as_tensor(
-            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
-    ],
-                                        dim=-1)
+    padded_state_indices = torch.concat(
+        [
+            state_indices,
+            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
+        ],
+        dim=-1,
+    )

-    has_initial_state = torch.randint(0,
-                                      2, (cumsum.shape[0] - 1, ),
-                                      dtype=torch.bool,
-                                      device=u.device)
-    out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
-                            delta_softplus, cumsum, padded_state_indices,
-                            has_initial_state)
+    has_initial_state = torch.randint(
+        0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=u.device
+    )
+    out = selective_scan_fn(
+        u,
+        prev_state,
+        delta,
+        A,
+        B,
+        C,
+        D,
+        z,
+        delta_bias,
+        delta_softplus,
+        cumsum,
+        padded_state_indices,
+        has_initial_state,
+    )
     outs_ref = []
     splits = [
         torch.split(var, seqlens[0], dim=-1)
@@ -530,33 +558,46 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
             delta_softplus=delta_softplus,
             return_last_state=return_last_state,
             prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0)
-            if has_initial_state[i] else None,
-            final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(
-                0))
+            if has_initial_state[i]
+            else None,
+            final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(0),
+        )
         outs_ref.append(out_ref_s)
     out_ref = torch.cat(outs_ref, dim=-1)[0]

-    unpadded_out = out[:, :out_ref[0].shape[-1]]
+    unpadded_out = out[:, : out_ref[0].shape[-1]]
     print("Output diff max", (unpadded_out - out_ref).max())
     print("Output diff mean", (unpadded_out - out_ref).mean())
     print("Output state diff max", (prev_state - prev_state_ref).max())
     print("Output state diff mean", (prev_state - prev_state_ref).mean())
     assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)
-    selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias,
-                              delta_softplus, cumsum, padded_state_indices,
-                              has_initial_state, prev_state)
+    selective_scan_opcheck_fn(
+        u,
+        delta,
+        A,
+        B,
+        C,
+        D,
+        z,
+        delta_bias,
+        delta_softplus,
+        cumsum,
+        padded_state_indices,
+        has_initial_state,
+        prev_state,
+    )


-@pytest.mark.parametrize("itype",
-                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("has_z", [True])
 @pytest.mark.parametrize("dstate", [16, 32, 64])
 @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
 # tests correctness in case subset of the sequences are padded
 @pytest.mark.parametrize("with_padding", [True, False])
-def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
-                                                   has_z, itype):
+def test_selective_state_update_with_batch_indices(
+    with_padding, dim, dstate, has_z, itype
+):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
     if itype == torch.bfloat16:
@@ -571,17 +612,17 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
     total_entries = 10 * batch_size
     state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
     state_indices = torch.randperm(total_entries)[:batch_size].to(
-        dtype=torch.int32, device=device)
-    unused_states_bool = torch.ones(total_entries,
-                                    dtype=torch.bool,
-                                    device=device)
+        dtype=torch.int32, device=device
+    )
+    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
     unused_states_bool[state_indices] = False
-    padded_state_indices = torch.concat([
-        state_indices,
-        torch.as_tensor(
-            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
-    ],
-                                        dim=0)
+    padded_state_indices = torch.concat(
+        [
+            state_indices,
+            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
+        ],
+        dim=0,
+    )
     x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
     out = torch.empty_like(x)
     dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
@@ -593,61 +634,60 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
     z = torch.randn_like(x) if has_z else None
     state_ref = state[state_indices, :].clone()
     state_before = state.clone()
-    selective_state_update(state,
-                           x,
-                           dt,
-                           A,
-                           B,
-                           C,
-                           D=D,
-                           z=z,
-                           dt_bias=dt_bias,
-                           dt_softplus=True,
-                           state_batch_indices=padded_state_indices,
-                           pad_slot_id=PAD_SLOT_ID,
-                           out=out)
-    out_ref = selective_state_update_ref(state_ref,
-                                         x[:batch_size],
-                                         dt[:batch_size],
-                                         A,
-                                         B[:batch_size],
-                                         C[:batch_size],
-                                         D=D,
-                                         z=z[:batch_size],
-                                         dt_bias=dt_bias,
-                                         dt_softplus=True)
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        state_batch_indices=padded_state_indices,
+        pad_slot_id=PAD_SLOT_ID,
+        out=out,
+    )
+    out_ref = selective_state_update_ref(
+        state_ref,
+        x[:batch_size],
+        dt[:batch_size],
+        A,
+        B[:batch_size],
+        C[:batch_size],
+        D=D,
+        z=z[:batch_size],
+        dt_bias=dt_bias,
+        dt_softplus=True,
+    )

     print("Output diff max", (out[:batch_size] - out_ref).max())
     print("Output diff mean", (out[:batch_size] - out_ref).mean())
     print("Output state diff max", (state[state_indices, :] - state_ref).max())
-    print("Output state diff mean",
-          (state[state_indices, :] - state_ref).mean())
+    print("Output state diff mean", (state[state_indices, :] - state_ref).mean())
     # test padded entries stay the same
     if with_padding:
-        assert torch.equal(state_before[unused_states_bool],
-                           state[unused_states_bool])
-        assert torch.equal(x[batch_size + 1:], x[batch_size + 1:])
-        assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:])
-        assert torch.equal(B[batch_size + 1:], B[batch_size + 1:])
-        assert torch.equal(C[batch_size + 1:], C[batch_size + 1:])
+        assert torch.equal(state_before[unused_states_bool], state[unused_states_bool])
+        assert torch.equal(x[batch_size + 1 :], x[batch_size + 1 :])
+        assert torch.equal(dt[batch_size + 1 :], dt[batch_size + 1 :])
+        assert torch.equal(B[batch_size + 1 :], B[batch_size + 1 :])
+        assert torch.equal(C[batch_size + 1 :], C[batch_size + 1 :])

     # test "real" entries
-    assert torch.allclose(state[state_indices, :],
-                          state_ref,
-                          rtol=rtol,
-                          atol=atol)
+    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)


-@pytest.mark.parametrize("itype",
-                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("has_z", [False, True])
 @pytest.mark.parametrize("tie_hdim", [False, True])
 @pytest.mark.parametrize("ngroups", [1, 2, 4])
 @pytest.mark.parametrize("dstate", [16, 32, 64])
 @pytest.mark.parametrize("dim", [2048, 4096])
 def test_selective_state_update_with_heads_with_batch_indices(
-        dim, dstate, ngroups, has_z, tie_hdim, itype):
+    dim, dstate, ngroups, has_z, tie_hdim, itype
+):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
     if itype == torch.bfloat16:
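A note on the slice changes in the hunk above: ruff's formatter follows PEP 8 in treating the slice colon as a binary operator when an operand is a compound expression, so it adds symmetric spaces around it. A tiny sketch, reusing names from the hunk for illustration only:

    x[batch_size + 1:]   # yapf left the colon tight
    x[batch_size + 1 :]  # ruff format spaces both sides of ":"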
@@ -659,71 +699,55 @@
     nheads = dim // headdim

     total_entries = 10 * batch_size
-    state = torch.randn(total_entries,
-                        nheads,
-                        headdim,
-                        dstate,
-                        dtype=itype,
-                        device=device)
+    state = torch.randn(
+        total_entries, nheads, headdim, dstate, dtype=itype, device=device
+    )
     state_indices = torch.randperm(total_entries)[:batch_size].to(
-        dtype=torch.int32, device=device)
+        dtype=torch.int32, device=device
+    )

     x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
     out = torch.empty_like(x)
     if not tie_hdim:
-        dt = torch.randn(batch_size,
-                         nheads,
-                         headdim,
-                         device=device,
-                         dtype=itype)
+        dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
         dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
         A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
         D = torch.randn(nheads, headdim, device=device)
     else:
-        dt = repeat(torch.randn(batch_size, nheads, device=device,
-                                dtype=itype),
-                    "b h -> b h p",
-                    p=headdim)
-        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
-                         "h -> h p",
-                         p=headdim)
-        A = repeat(-torch.rand(nheads, device=device) - 1.0,
-                   "h -> h p n",
-                   p=headdim,
-                   n=dstate)
+        dt = repeat(
+            torch.randn(batch_size, nheads, device=device, dtype=itype),
+            "b h -> b h p",
+            p=headdim,
+        )
+        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
+        A = repeat(
+            -torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate
+        )
         D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
     B = torch.randn(batch_size, ngroups, dstate, device=device)
     C = torch.randn(batch_size, ngroups, dstate, device=device)
     z = torch.randn_like(x) if has_z else None
     state_ref = state[state_indices, :].detach().clone()
-    selective_state_update(state,
-                           x,
-                           dt,
-                           A,
-                           B,
-                           C,
-                           D=D,
-                           z=z,
-                           dt_bias=dt_bias,
-                           dt_softplus=True,
-                           state_batch_indices=state_indices,
-                           pad_slot_id=PAD_SLOT_ID,
-                           out=out)
-    out_ref = selective_state_update_ref(state_ref,
-                                         x,
-                                         dt,
-                                         A,
-                                         B,
-                                         C,
-                                         D=D,
-                                         z=z,
-                                         dt_bias=dt_bias,
-                                         dt_softplus=True)
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        state_batch_indices=state_indices,
+        pad_slot_id=PAD_SLOT_ID,
+        out=out,
+    )
+    out_ref = selective_state_update_ref(
+        state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
+    )

     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
-    assert torch.allclose(state[state_indices, :],
-                          state_ref,
-                          rtol=rtol,
-                          atol=atol)
+    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)