Categorize tests/kernels/ based on kernel type (#16799)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
437
tests/kernels/mamba/test_causal_conv1d.py
Normal file
437
tests/kernels/mamba/test_causal_conv1d.py
Normal file
@@ -0,0 +1,437 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
initial_states: Optional[torch.Tensor] = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1)
|
||||
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
x = x.to(weight.dtype)
|
||||
seqlen = x.shape[-1]
|
||||
dim, width = weight.shape
|
||||
if initial_states is None:
|
||||
out = F.conv1d(x,
|
||||
weight.unsqueeze(1),
|
||||
bias,
|
||||
padding=width - 1,
|
||||
groups=dim)
|
||||
else:
|
||||
x = torch.cat([initial_states, x], dim=-1)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
||||
out = out[..., :seqlen]
|
||||
if return_final_states:
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out.copy_(final_states)
|
||||
else:
|
||||
final_states_out = final_states
|
||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
return (out, None) if not return_final_states else (out, final_states_out)
|
||||
|
||||
|
||||
def causal_conv1d_update_ref(x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias=None,
|
||||
activation=None,
|
||||
cache_seqlens=None):
|
||||
"""
|
||||
x: (batch, dim) or (batch, dim, seqlen)
|
||||
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
cache_seqlens: (batch,), dtype int32.
|
||||
If not None, the conv_state is treated as a circular buffer.
|
||||
The conv_state will be updated by copying x to the
|
||||
conv_state starting at the index
|
||||
@cache_seqlens % state_len before performing the convolution.
|
||||
|
||||
out: (batch, dim) or (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
unsqueeze = x.dim() == 2
|
||||
if unsqueeze:
|
||||
x = x.unsqueeze(-1)
|
||||
batch, dim, seqlen = x.shape
|
||||
width = weight.shape[1]
|
||||
state_len = conv_state.shape[-1]
|
||||
assert conv_state.shape == (batch, dim, state_len)
|
||||
assert weight.shape == (dim, width)
|
||||
if cache_seqlens is None:
|
||||
x_new = torch.cat([conv_state, x], dim=-1).to(
|
||||
weight.dtype) # (batch, dim, state_len + seqlen)
|
||||
conv_state.copy_(x_new[:, :, -state_len:])
|
||||
else:
|
||||
width_idx = torch.arange(
|
||||
-(width - 1), 0, dtype=torch.long,
|
||||
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
||||
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(
|
||||
-1, dim, -1)
|
||||
x_new = torch.cat([conv_state.gather(2, width_idx), x],
|
||||
dim=-1).to(weight.dtype)
|
||||
copy_idx = torch.arange(
|
||||
seqlen, dtype=torch.long,
|
||||
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
||||
copy_idx = torch.remainder(copy_idx,
|
||||
state_len).unsqueeze(1).expand(-1, dim, -1)
|
||||
conv_state.scatter_(2, copy_idx, x)
|
||||
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
|
||||
groups=dim)[:, :, -seqlen:]
|
||||
if unsqueeze:
|
||||
out = out.squeeze(-1)
|
||||
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
def causal_conv1d_opcheck_fn(x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cu_seq_len: Optional[torch.Tensor] = None,
|
||||
cache_indices: Optional[torch.Tensor] = None,
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
conv_states: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
seq_idx: (batch, seqlen)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1), to be written to
|
||||
activation: either None or "silu" or "swish"
|
||||
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_fwd,
|
||||
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
|
||||
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
@pytest.mark.parametrize("has_initial_state", [True, False])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize(
|
||||
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
|
||||
@pytest.mark.parametrize('dim', [64])
|
||||
@pytest.mark.parametrize('batch', [1])
|
||||
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
||||
has_initial_state, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
x = torch.randn(batch, dim, seqlen, device=device,
|
||||
dtype=itype).contiguous()
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
if has_initial_state:
|
||||
initial_states = torch.randn(batch,
|
||||
dim,
|
||||
width - 1,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
has_initial_state_tensor = torch.ones(batch,
|
||||
dtype=torch.bool,
|
||||
device=x.device)
|
||||
else:
|
||||
initial_states = None
|
||||
has_initial_state_tensor = None
|
||||
x_ref = x.clone()
|
||||
weight_ref = weight.clone()
|
||||
bias_ref = bias.clone() if bias is not None else None
|
||||
initial_states_ref = initial_states.clone(
|
||||
) if initial_states is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_states=initial_states,
|
||||
has_initial_state=has_initial_state_tensor)
|
||||
out_ref, final_states_ref = causal_conv1d_ref(
|
||||
x_ref,
|
||||
weight_ref,
|
||||
bias_ref,
|
||||
initial_states=initial_states_ref,
|
||||
return_final_states=True,
|
||||
activation=activation)
|
||||
if has_initial_state:
|
||||
assert initial_states is not None and final_states_ref is not None
|
||||
assert torch.allclose(initial_states,
|
||||
final_states_ref,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
causal_conv1d_opcheck_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_states=initial_states,
|
||||
has_initial_state=has_initial_state_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("seqlen", [1])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
|
||||
itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
batch = 2
|
||||
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
|
||||
x_ref = x.clone()
|
||||
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state.detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_update(x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
out_ref = causal_conv1d_update_ref(x_ref,
|
||||
conv_state_ref,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_update,
|
||||
(x, conv_state, weight, bias, activation
|
||||
in ["silu", "swish"], None, None, PAD_SLOT_ID))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
||||
@pytest.mark.parametrize("width", [2, 3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
seqlen, has_bias,
|
||||
silu_activation, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
batch_size = 3
|
||||
padding = 5 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
total_entries = 10 * batch_size
|
||||
|
||||
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
|
||||
x_ref = x.clone()
|
||||
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
dtype=torch.int32, device=device)
|
||||
unused_states_bool = torch.ones(total_entries,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
unused_states_bool[conv_state_indices] = False
|
||||
padded_state_indices = torch.concat([
|
||||
conv_state_indices,
|
||||
torch.as_tensor(
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
|
||||
],
|
||||
dim=0)
|
||||
conv_state = torch.randn(total_entries,
|
||||
dim,
|
||||
width - 1,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
conv_state_for_padding_test = conv_state.clone()
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_update(x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID)
|
||||
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
|
||||
conv_state_ref,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
|
||||
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
assert torch.equal(conv_state[unused_states_bool],
|
||||
conv_state_for_padding_test[unused_states_bool])
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_update,
|
||||
(x, conv_state, weight, bias, activation
|
||||
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize(
|
||||
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
|
||||
@pytest.mark.parametrize('dim', [64, 4096])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize('with_padding', [True, False])
|
||||
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
silu_activation, itype):
|
||||
device = "cuda"
|
||||
torch.cuda.empty_cache()
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
seqlens = []
|
||||
batch_size = 4
|
||||
if seqlen < 10:
|
||||
batch_size = 1
|
||||
padding = 3 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
nsplits = padded_batch_size - 1
|
||||
|
||||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
||||
seqlens.append(
|
||||
torch.diff(
|
||||
torch.cat(
|
||||
[torch.tensor([-1]), eos_pos,
|
||||
torch.tensor([seqlen - 1])])).tolist())
|
||||
assert sum(seqlens[-1]) == seqlen
|
||||
assert all(s > 0 for s in seqlens[-1])
|
||||
|
||||
total_entries = batch_size * 10
|
||||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
||||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
|
||||
dim=0)
|
||||
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
|
||||
dtype=itype)[:, 4096:4096 + dim, :]
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
x_ref = x.clone()
|
||||
weight_ref = weight.clone()
|
||||
bias_ref = bias.clone() if bias is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
final_states = torch.randn(total_entries,
|
||||
dim,
|
||||
width - 1,
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
final_states_ref = final_states.clone()
|
||||
has_initial_states = torch.randint(0,
|
||||
2, (cumsum.shape[0] - 1, ),
|
||||
dtype=torch.bool,
|
||||
device=x.device)
|
||||
state_indices = torch.randperm(total_entries,
|
||||
dtype=torch.int32,
|
||||
device=x.device)[:batch_size]
|
||||
padded_state_indices = torch.concat([
|
||||
state_indices,
|
||||
torch.as_tensor(
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
],
|
||||
dim=-1)
|
||||
|
||||
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
||||
padded_state_indices, has_initial_states,
|
||||
final_states, activation, PAD_SLOT_ID)
|
||||
out_ref = []
|
||||
out_ref_b = []
|
||||
|
||||
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
|
||||
for i in range(len(seqlens[0])):
|
||||
x_s = [v[i].unsqueeze(0) for v in splits][0]
|
||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
||||
continue
|
||||
out_ref_b.append(
|
||||
causal_conv1d_ref(
|
||||
x_s,
|
||||
weight_ref,
|
||||
bias_ref,
|
||||
activation=activation,
|
||||
return_final_states=True,
|
||||
final_states_out=final_states_ref[
|
||||
padded_state_indices[i]].unsqueeze(0),
|
||||
initial_states=final_states_ref[padded_state_indices[i]].
|
||||
unsqueeze(0) if has_initial_states[i] else None))
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
|
||||
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(final_states[state_indices],
|
||||
final_states_ref[state_indices],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
||||
padded_state_indices, has_initial_states,
|
||||
final_states, activation)
|
||||
124
tests/kernels/mamba/test_mamba_mixer2.py
Normal file
124
tests/kernels/mamba/test_mamba_mixer2.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [128])
|
||||
@pytest.mark.parametrize(
|
||||
"hidden_size_n_groups",
|
||||
[
|
||||
(64, 1),
|
||||
(64, 2),
|
||||
(64, 4), # hidden_size be divisible by num_gpus
|
||||
(100, 5), # and n_groups must divide hidden_size
|
||||
])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
def test_mixer2_gated_norm_multi_gpu(
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size_n_groups: tuple[int, int],
|
||||
dtype: torch.dtype,
|
||||
device: str = 'cuda',
|
||||
):
|
||||
hidden_size, n_groups = hidden_size_n_groups
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
# need to use torch.mp.spawn otherwise will have problems with
|
||||
# torch.distributed and cuda
|
||||
torch.multiprocessing.spawn(fn,
|
||||
args=(
|
||||
num_processes,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
n_groups,
|
||||
dtype,
|
||||
device,
|
||||
),
|
||||
nprocs=nprocs)
|
||||
|
||||
run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2)
|
||||
|
||||
|
||||
def mixer2_gated_norm_tensor_parallel(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
n_groups: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# create random weights an inputs
|
||||
weight = torch.rand((hidden_size, ), dtype=dtype, device=device)
|
||||
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
|
||||
gate_states = torch.randn(batch_size, seq_len, hidden_size)
|
||||
|
||||
# create gated-norm with TP
|
||||
mixer = Mixer2RMSNormGated(
|
||||
full_hidden_size=hidden_size,
|
||||
full_n_groups=n_groups,
|
||||
)
|
||||
mixer.weight.weight_loader(mixer.weight, weight) # load
|
||||
|
||||
# create gated-norm without TP to compute reference
|
||||
# - utilize mock patching to disable TP when
|
||||
with (unittest.mock.patch(
|
||||
"vllm.model_executor.layers.mamba.mamba_mixer2."
|
||||
"get_tensor_model_parallel_world_size",
|
||||
return_value=1),
|
||||
unittest.mock.patch(
|
||||
"vllm.model_executor.layers.mamba.mamba_mixer2."
|
||||
"get_tensor_model_parallel_rank",
|
||||
return_value=0)):
|
||||
mixer_single_gpu = Mixer2RMSNormGated(
|
||||
full_hidden_size=hidden_size,
|
||||
full_n_groups=n_groups,
|
||||
)
|
||||
# assign weight to single-gpu mixer
|
||||
mixer_single_gpu.weight.data = weight
|
||||
|
||||
# generate and compare
|
||||
N = hidden_size // world_size
|
||||
output = mixer(
|
||||
hidden_states[..., local_rank * N:(local_rank + 1) * N],
|
||||
gate_states[..., local_rank * N:(local_rank + 1) * N],
|
||||
)
|
||||
ref_output = mixer_single_gpu(hidden_states, gate_states)
|
||||
torch.allclose(output,
|
||||
ref_output[..., local_rank * N:(local_rank + 1) * N],
|
||||
atol=1e-3,
|
||||
rtol=1e-3)
|
||||
722
tests/kernels/mamba/test_mamba_ssm.py
Normal file
722
tests/kernels/mamba/test_mamba_ssm.py
Normal file
@@ -0,0 +1,722 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def selective_state_update_ref(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
Return:
|
||||
out: (batch, dim) or (batch, nheads, dim)
|
||||
"""
|
||||
has_heads = state.dim() > 3
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
batch, nheads, dim, dstate = state.shape
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
dt = dt + dt_bias
|
||||
dt = F.softplus(dt) if dt_softplus else dt
|
||||
dA = torch.exp(rearrange(dt, "b h d -> b h d 1") *
|
||||
A) # (batch, nheads, dim, dstate)
|
||||
B = repeat(B, "b g n -> b (g h) n",
|
||||
h=nheads // ngroups) # (batch, nheads, dstate)
|
||||
C = repeat(C, "b g n -> b (g h) n",
|
||||
h=nheads // ngroups) # (batch, nheads, dstate)
|
||||
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
|
||||
B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
|
||||
state.copy_(state * dA +
|
||||
dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate
|
||||
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
|
||||
if D is not None:
|
||||
out += (x * D).to(out.dtype)
|
||||
out = (out if z is None else out * F.silu(z)).to(x.dtype)
|
||||
if not has_heads:
|
||||
out = out.squeeze(1)
|
||||
return out
|
||||
|
||||
|
||||
def selective_scan_ref(u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
return_last_state=False,
|
||||
prev_state=None,
|
||||
final_state_out=None):
|
||||
"""
|
||||
u: r(B D L)
|
||||
delta: r(B D L)
|
||||
A: c(D N) or r(D N)
|
||||
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
D: r(D)
|
||||
z: r(B D L)
|
||||
delta_bias: r(D), fp32
|
||||
prev_state: r(B D N), fp32
|
||||
|
||||
out: r(B D L)
|
||||
last_state (optional): r(B D dstate) or c(B D dstate)
|
||||
"""
|
||||
dtype_in = u.dtype
|
||||
u = u.float()
|
||||
delta = delta.float()
|
||||
if delta_bias is not None:
|
||||
delta = delta + delta_bias[..., None].float()
|
||||
if delta_softplus:
|
||||
delta = F.softplus(delta)
|
||||
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
||||
is_variable_B = B.dim() >= 3
|
||||
is_variable_C = C.dim() >= 3
|
||||
B = B.float()
|
||||
C = C.float()
|
||||
x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
|
||||
ys = []
|
||||
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
||||
if not is_variable_B:
|
||||
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
|
||||
else:
|
||||
if B.dim() == 3:
|
||||
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
|
||||
else:
|
||||
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
||||
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
|
||||
if is_variable_C and C.dim() == 4:
|
||||
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
||||
for i in range(u.shape[2]):
|
||||
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
||||
if not is_variable_C:
|
||||
y = torch.einsum('bdn,dn->bd', x, C)
|
||||
else:
|
||||
if C.dim() == 3:
|
||||
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
|
||||
else:
|
||||
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
|
||||
if i == u.shape[2] - 1:
|
||||
if final_state_out is None:
|
||||
final_state_out = x
|
||||
else:
|
||||
final_state_out.copy_(x)
|
||||
ys.append(y)
|
||||
y = torch.stack(ys, dim=2) # (batch dim L)
|
||||
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
||||
if z is not None:
|
||||
out = out * F.silu(z)
|
||||
out = out.to(dtype=dtype_in)
|
||||
return out if not return_last_state else (out, final_state_out)
|
||||
|
||||
|
||||
def selective_scan_opcheck_fn(u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
cu_seq_len=None,
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
ssm_states=None,
|
||||
pad_slot_id=PAD_SLOT_ID):
|
||||
"""if return_last_state is True, returns (out, last_state)
|
||||
last_state has shape (batch, dim, dstate).
|
||||
"""
|
||||
if u.stride(-1) != 1:
|
||||
u = u.contiguous()
|
||||
if delta.stride(-1) != 1:
|
||||
delta = delta.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if z is not None and z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
if B.dim() == 3 and cu_seq_len is None:
|
||||
B = B.unsqueeze(1)
|
||||
if B.dim() == 2 and cu_seq_len is not None:
|
||||
B = B.unsqueeze(0)
|
||||
if C.dim() == 3 and cu_seq_len is None:
|
||||
C = C.unsqueeze(1)
|
||||
if C.dim() == 2 and cu_seq_len is not None:
|
||||
C = C.unsqueeze(0)
|
||||
|
||||
# Disable test_autograd_registration for now as it seems to trigger
|
||||
# a bogus error.
|
||||
opcheck(torch.ops._C.selective_scan_fwd,
|
||||
(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len,
|
||||
cache_indices, has_initial_state, ssm_states, pad_slot_id),
|
||||
test_utils=["test_schema", "test_faketensor"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wtype', [torch.float32])
|
||||
@pytest.mark.parametrize('itype',
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
|
||||
@pytest.mark.parametrize('has_delta_bias', [True])
|
||||
@pytest.mark.parametrize('delta_softplus', [True])
|
||||
@pytest.mark.parametrize('has_z', [True])
|
||||
@pytest.mark.parametrize('has_D', [True])
|
||||
@pytest.mark.parametrize("varBC_groups", [1, 2])
|
||||
@pytest.mark.parametrize("is_variable_C", [True])
|
||||
@pytest.mark.parametrize("is_variable_B", [True])
|
||||
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
|
||||
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
|
||||
has_z, has_delta_bias, delta_softplus, seqlen, itype,
|
||||
wtype, scan_chunks):
|
||||
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
|
||||
pytest.skip() # This config is not applicable
|
||||
device = 'cuda'
|
||||
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 3e-2, 5e-2
|
||||
rtolw, atolw = (1e-3, 1e-3)
|
||||
if has_z: # If we have z, the errors on the weights seem higher
|
||||
rtolw = max(rtolw, rtol)
|
||||
atolw = max(atolw, atol)
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
batch_size = 1
|
||||
dim = 4
|
||||
dstate = 8
|
||||
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
|
||||
A_ref = A.clone()
|
||||
if not is_variable_B:
|
||||
B_shape = [dim, dstate]
|
||||
elif varBC_groups == 1:
|
||||
B_shape = [batch_size, dstate, seqlen]
|
||||
else:
|
||||
B_shape = [batch_size, varBC_groups, dstate, seqlen]
|
||||
B = torch.randn(B_shape,
|
||||
device=device,
|
||||
dtype=wtype if not is_variable_B else itype)
|
||||
B_ref = B.clone()
|
||||
if not is_variable_C:
|
||||
C_shape = [dim, dstate]
|
||||
elif varBC_groups == 1:
|
||||
C_shape = [batch_size, dstate, seqlen]
|
||||
else:
|
||||
C_shape = [batch_size, varBC_groups, dstate, seqlen]
|
||||
C = torch.randn(C_shape,
|
||||
device=device,
|
||||
dtype=wtype if not is_variable_C else itype)
|
||||
C_ref = C.clone()
|
||||
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
|
||||
D_ref = D.clone()
|
||||
z = torch.randn(batch_size, dim, seqlen, device=device,
|
||||
dtype=itype) if has_z else None
|
||||
z_ref = z.clone() if has_z else None
|
||||
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
|
||||
) if has_delta_bias else None
|
||||
u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
|
||||
u_ref = u.clone()
|
||||
delta = (0.5 *
|
||||
torch.rand(batch_size, dim, seqlen, device=device, dtype=itype))
|
||||
delta_ref = delta.clone()
|
||||
state_shape = (batch_size, u.shape[1], int(A.shape[1]))
|
||||
state = torch.randn(state_shape,
|
||||
device=u.device,
|
||||
dtype=itype,
|
||||
requires_grad=False)
|
||||
state_ref = state.clone()
|
||||
out = None
|
||||
out_ref = None
|
||||
outs = []
|
||||
for c in range(scan_chunks):
|
||||
chunked_prompt_len = seqlen // scan_chunks
|
||||
chunk_start = chunked_prompt_len * c
|
||||
chunk_end = chunked_prompt_len * (c + 1)
|
||||
if c == scan_chunks - 1:
|
||||
chunk_end = seqlen
|
||||
_B = B
|
||||
if is_variable_B:
|
||||
_B = B[..., chunk_start:chunk_end]
|
||||
_C = C
|
||||
if is_variable_B:
|
||||
_C = C[..., chunk_start:chunk_end]
|
||||
_z = z
|
||||
if has_z:
|
||||
assert z is not None
|
||||
_z = z[..., chunk_start:chunk_end]
|
||||
out = selective_scan_fn(
|
||||
u[..., chunk_start:chunk_end],
|
||||
state,
|
||||
delta[..., chunk_start:chunk_end],
|
||||
A,
|
||||
_B,
|
||||
_C,
|
||||
D,
|
||||
z=_z,
|
||||
delta_bias=delta_bias,
|
||||
delta_softplus=delta_softplus,
|
||||
has_initial_state=torch.ones(batch_size,
|
||||
device=u.device,
|
||||
dtype=torch.bool) if c > 0 else None)
|
||||
outs.append(out)
|
||||
if len(outs) > 1:
|
||||
out = torch.cat(outs, dim=-1)
|
||||
|
||||
out_ref, state_ref, *rest = selective_scan_ref(
|
||||
u_ref,
|
||||
delta_ref,
|
||||
A_ref,
|
||||
B_ref,
|
||||
C_ref,
|
||||
D_ref,
|
||||
z=z_ref,
|
||||
delta_bias=delta_bias,
|
||||
delta_softplus=delta_softplus,
|
||||
return_last_state=True)
|
||||
|
||||
assert out is not None and out_ref is not None
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
assert state is not None and state_ref is not None
|
||||
assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol)
|
||||
|
||||
selective_scan_opcheck_fn(u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
delta_bias=delta_bias,
|
||||
delta_softplus=delta_softplus,
|
||||
ssm_states=state)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
batch_size = 1
|
||||
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||
B = torch.randn(batch_size, dstate, device=device)
|
||||
C = torch.randn(batch_size, dstate, device=device)
|
||||
D = torch.randn(dim, device=device)
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state.detach().clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
|
||||
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wtype', [torch.float32])
|
||||
@pytest.mark.parametrize('itype', [torch.float32])
|
||||
@pytest.mark.parametrize('seqlen', [1, 128, 129, 256, 512, 1024, 2048, 4096])
|
||||
@pytest.mark.parametrize("return_last_state", [True])
|
||||
@pytest.mark.parametrize('has_delta_bias', [True])
|
||||
@pytest.mark.parametrize('delta_softplus', [True])
|
||||
@pytest.mark.parametrize('has_z', [True])
|
||||
@pytest.mark.parametrize('has_D', [True])
|
||||
@pytest.mark.parametrize("varBC_groups", [1, 2])
|
||||
@pytest.mark.parametrize("is_variable_C", [True])
|
||||
@pytest.mark.parametrize("is_variable_B", [True])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [False, True])
|
||||
def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
|
||||
varBC_groups, has_D, has_z, has_delta_bias,
|
||||
delta_softplus, return_last_state, seqlen,
|
||||
itype, wtype):
|
||||
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
|
||||
pytest.skip() # This config is not applicable
|
||||
device = 'cuda'
|
||||
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 3e-2, 5e-2
|
||||
rtolw, atolw = (1e-3, 1e-3)
|
||||
if has_z: # If we have z, the errors on the weights seem higher
|
||||
rtolw = max(rtolw, rtol)
|
||||
atolw = max(atolw, atol)
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
seqlens = []
|
||||
batch_size = 4
|
||||
if seqlen < 10:
|
||||
batch_size = 1
|
||||
padding = 3 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
|
||||
if with_padding and seqlen < padded_batch_size:
|
||||
pytest.skip()
|
||||
|
||||
nsplits = padded_batch_size - 1
|
||||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
||||
seqlens.append(
|
||||
torch.diff(
|
||||
torch.cat(
|
||||
[torch.tensor([-1]), eos_pos,
|
||||
torch.tensor([seqlen - 1])])).tolist())
|
||||
|
||||
assert sum(seqlens[-1]) == seqlen
|
||||
assert all(s > 0 for s in seqlens[-1])
|
||||
|
||||
total_entries = batch_size * 10
|
||||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
||||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
|
||||
dim=0).cuda()
|
||||
|
||||
dim = 4
|
||||
dstate = 8
|
||||
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
|
||||
A_ref = A.clone()
|
||||
B_shape = [varBC_groups, dstate, seqlen]
|
||||
B = torch.randn(B_shape,
|
||||
device=device,
|
||||
dtype=wtype if not is_variable_B else itype)
|
||||
B_ref = B.clone()
|
||||
C_shape = [varBC_groups, dstate, seqlen]
|
||||
C = torch.randn(C_shape,
|
||||
device=device,
|
||||
dtype=wtype if not is_variable_C else itype)
|
||||
C_ref = C.clone()
|
||||
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
|
||||
D_ref = D.clone()
|
||||
z = torch.randn(dim, seqlen, device=device, dtype=itype)
|
||||
z_ref = z.clone()
|
||||
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
|
||||
) if has_delta_bias else None
|
||||
u = torch.randn(dim, seqlen, device=device, dtype=itype)
|
||||
u_ref = u.clone()
|
||||
delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype))
|
||||
delta_ref = delta.clone()
|
||||
out = None
|
||||
out_ref = None
|
||||
|
||||
prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
|
||||
prev_state = torch.randn(prev_state_shape,
|
||||
device=u.device,
|
||||
dtype=itype,
|
||||
requires_grad=False)
|
||||
prev_state_ref = prev_state.clone()
|
||||
state_indices = torch.randperm(total_entries,
|
||||
dtype=torch.int32,
|
||||
device=u.device)[:batch_size]
|
||||
unused_states_bool = torch.ones(total_entries,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
unused_states_bool[state_indices] = False
|
||||
padded_state_indices = torch.concat([
|
||||
state_indices,
|
||||
torch.as_tensor(
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
],
|
||||
dim=-1)
|
||||
|
||||
has_initial_state = torch.randint(0,
|
||||
2, (cumsum.shape[0] - 1, ),
|
||||
dtype=torch.bool,
|
||||
device=u.device)
|
||||
out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
|
||||
delta_softplus, cumsum, padded_state_indices,
|
||||
has_initial_state)
|
||||
outs_ref = []
|
||||
splits = [
|
||||
torch.split(var, seqlens[0], dim=-1)
|
||||
for var in (u_ref, delta_ref, B_ref, C_ref, z_ref)
|
||||
]
|
||||
for i in range(len(seqlens[0])):
|
||||
u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
|
||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
||||
continue
|
||||
out_ref_s, _ = selective_scan_ref(
|
||||
u_s,
|
||||
delta_s,
|
||||
A_ref,
|
||||
B_s,
|
||||
C_s,
|
||||
D_ref,
|
||||
z=z_s,
|
||||
delta_bias=delta_bias,
|
||||
delta_softplus=delta_softplus,
|
||||
return_last_state=return_last_state,
|
||||
prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0)
|
||||
if has_initial_state[i] else None,
|
||||
final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(
|
||||
0))
|
||||
outs_ref.append(out_ref_s)
|
||||
out_ref = torch.cat(outs_ref, dim=-1)[0]
|
||||
|
||||
unpadded_out = out[:, :out_ref[0].shape[-1]]
|
||||
print("Output diff max", (unpadded_out - out_ref).max())
|
||||
print("Output diff mean", (unpadded_out - out_ref).mean())
|
||||
print("Output state diff max", (prev_state - prev_state_ref).max())
|
||||
print("Output state diff mean", (prev_state - prev_state_ref).mean())
|
||||
assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)
|
||||
selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias,
|
||||
delta_softplus, cumsum, padded_state_indices,
|
||||
has_initial_state, prev_state)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [True])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
|
||||
has_z, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-1, 1e-1
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 3
|
||||
padding = 5 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
total_entries = 10 * batch_size
|
||||
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
|
||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
dtype=torch.int32, device=device)
|
||||
unused_states_bool = torch.ones(total_entries,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
unused_states_bool[state_indices] = False
|
||||
padded_state_indices = torch.concat([
|
||||
state_indices,
|
||||
torch.as_tensor(
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
|
||||
],
|
||||
dim=0)
|
||||
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
|
||||
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
|
||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||
B = torch.randn(padded_batch_size, dstate, device=device)
|
||||
C = torch.randn(padded_batch_size, dstate, device=device)
|
||||
D = torch.randn(dim, device=device)
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state[state_indices, :].clone()
|
||||
state_before = state.clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x[:batch_size],
|
||||
dt[:batch_size],
|
||||
A,
|
||||
B[:batch_size],
|
||||
C[:batch_size],
|
||||
D=D,
|
||||
z=z[:batch_size],
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
|
||||
print("Output diff max", (out[:batch_size] - out_ref).max())
|
||||
print("Output diff mean", (out[:batch_size] - out_ref).mean())
|
||||
print("Output state diff max", (state[state_indices, :] - state_ref).max())
|
||||
print("Output state diff mean",
|
||||
(state[state_indices, :] - state_ref).mean())
|
||||
# test padded entries stay the same
|
||||
if with_padding:
|
||||
assert torch.equal(state_before[unused_states_bool],
|
||||
state[unused_states_bool])
|
||||
assert torch.equal(x[batch_size + 1:], x[batch_size + 1:])
|
||||
assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:])
|
||||
assert torch.equal(B[batch_size + 1:], B[batch_size + 1:])
|
||||
assert torch.equal(C[batch_size + 1:], C[batch_size + 1:])
|
||||
|
||||
# test "real" entries
|
||||
assert torch.allclose(state[state_indices, :],
|
||||
state_ref,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
@pytest.mark.parametrize("tie_hdim", [False, True])
|
||||
@pytest.mark.parametrize("ngroups", [1, 2, 4])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 4096])
|
||||
def test_selective_state_update_with_heads_with_batch_indices(
|
||||
dim, dstate, ngroups, has_z, tie_hdim, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-1, 1e-1
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 3
|
||||
headdim = 64
|
||||
nheads = dim // headdim
|
||||
|
||||
total_entries = 10 * batch_size
|
||||
state = torch.randn(total_entries,
|
||||
nheads,
|
||||
headdim,
|
||||
dstate,
|
||||
dtype=itype,
|
||||
device=device)
|
||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
dtype=torch.int32, device=device)
|
||||
|
||||
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
|
||||
if not tie_hdim:
|
||||
dt = torch.randn(batch_size,
|
||||
nheads,
|
||||
headdim,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
|
||||
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
|
||||
D = torch.randn(nheads, headdim, device=device)
|
||||
else:
|
||||
dt = repeat(torch.randn(batch_size, nheads, device=device,
|
||||
dtype=itype),
|
||||
"b h -> b h p",
|
||||
p=headdim)
|
||||
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
|
||||
"h -> h p",
|
||||
p=headdim)
|
||||
A = repeat(-torch.rand(nheads, device=device) - 1.0,
|
||||
"h -> h p n",
|
||||
p=headdim,
|
||||
n=dstate)
|
||||
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
|
||||
B = torch.randn(batch_size, ngroups, dstate, device=device)
|
||||
C = torch.randn(batch_size, ngroups, dstate, device=device)
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state[state_indices, :].detach().clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
assert torch.allclose(state[state_indices, :],
|
||||
state_ref,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
309
tests/kernels/mamba/test_mamba_ssm_ssd.py
Normal file
309
tests/kernels/mamba/test_mamba_ssm_ssd.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
_seq_idx_to_chunk_indices_offsets)
|
||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
|
||||
|
||||
|
||||
# this is the segsum implementation taken from above
|
||||
def segsum(x):
|
||||
"""Calculates segment sum."""
|
||||
T = x.size(-1)
|
||||
x = repeat(x, "... d -> ... d e", e=T)
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
|
||||
diagonal=-1)
|
||||
x = x.masked_fill(~mask, 0)
|
||||
x_segsum = torch.cumsum(x, dim=-2)
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
|
||||
diagonal=0)
|
||||
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
||||
return x_segsum
|
||||
|
||||
|
||||
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
|
||||
"""
|
||||
Arguments:
|
||||
X: (batch, length, n_heads, d_head)
|
||||
A: (batch, length, n_heads)
|
||||
B: (batch, length, n_heads, d_state)
|
||||
C: (batch, length, n_heads, d_state)
|
||||
Return:
|
||||
Y: (batch, length, n_heads, d_head)
|
||||
"""
|
||||
assert X.dtype == A.dtype == B.dtype == C.dtype
|
||||
assert X.shape[1] % block_len == 0
|
||||
|
||||
# Rearrange into blocks/chunks
|
||||
X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len)
|
||||
for x in (X, A, B, C))
|
||||
|
||||
A = rearrange(A, "b c l h -> b h c l")
|
||||
A_cumsum = torch.cumsum(A, dim=-1)
|
||||
|
||||
# 1. Compute the output for each intra-chunk (diagonal blocks)
|
||||
L = torch.exp(segsum(A))
|
||||
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
|
||||
|
||||
# 2. Compute the state for each intra-chunk
|
||||
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
|
||||
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
|
||||
# chunk boundaries
|
||||
# (middle term of factorization of off-diag blocks; A terms)
|
||||
if initial_states is None:
|
||||
initial_states = torch.zeros_like(states[:, :1])
|
||||
states = torch.cat([initial_states, states], dim=1)
|
||||
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
|
||||
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
|
||||
states, final_state = new_states[:, :-1], new_states[:, -1]
|
||||
|
||||
# 4. Compute state -> output conversion per chunk
|
||||
# (left term of low-rank factorization of off-diagonal blocks; C terms)
|
||||
state_decay_out = torch.exp(A_cumsum)
|
||||
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
|
||||
|
||||
# Add output of intra-chunk and inter-chunk terms
|
||||
# (diagonal and off-diagonal blocks)
|
||||
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
|
||||
return Y, final_state
|
||||
|
||||
|
||||
def generate_random_inputs(batch_size,
|
||||
seqlen,
|
||||
n_heads,
|
||||
d_head,
|
||||
itype,
|
||||
device='cuda'):
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device)))
|
||||
dt = F.softplus(
|
||||
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) -
|
||||
4)
|
||||
X = torch.randn((batch_size, seqlen, n_heads, d_head),
|
||||
dtype=itype,
|
||||
device=device)
|
||||
B = torch.randn((batch_size, seqlen, n_heads, d_head),
|
||||
dtype=itype,
|
||||
device=device)
|
||||
C = torch.randn((batch_size, seqlen, n_heads, d_head),
|
||||
dtype=itype,
|
||||
device=device)
|
||||
|
||||
return A, dt, X, B, C
|
||||
|
||||
|
||||
def generate_continous_batched_examples(example_lens_by_batch,
|
||||
num_examples,
|
||||
full_length,
|
||||
last_taken,
|
||||
exhausted,
|
||||
n_heads,
|
||||
d_head,
|
||||
itype,
|
||||
device='cuda'):
|
||||
|
||||
# this function generates a random examples of certain length
|
||||
# and then cut according to "example_lens_by_batch" and feed
|
||||
# them in continuous batches to the kernels
|
||||
|
||||
# generate the full-length example
|
||||
A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
|
||||
d_head, itype)
|
||||
|
||||
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
|
||||
A * dt,
|
||||
B,
|
||||
C,
|
||||
block_len=full_length // 4)
|
||||
|
||||
# internal function that outputs a cont batch of examples
|
||||
# given a tuple of lengths for each example in the batch
|
||||
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
|
||||
# 4 examples from second eg, etc
|
||||
def get_continuous_batch(example_lens: tuple[int, ...]):
|
||||
|
||||
indices = []
|
||||
for i, x in enumerate(example_lens):
|
||||
c = last_taken.get(i, 0)
|
||||
indices.append((c, c + x))
|
||||
last_taken[i] = (c + x) % full_length
|
||||
exhausted[i] = last_taken[i] == 0
|
||||
|
||||
return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)
|
||||
]).unsqueeze(0) for x in (dt, X, B, C))
|
||||
|
||||
# internal function that maps "n" to the appropriate right boundary
|
||||
# value when forming continuous batches from examples of length given
|
||||
# by "full_length".
|
||||
# - e.g., when n > full_length, returns n % full_length
|
||||
# when n == full_length, returns full_length
|
||||
def end_boundary(n: int):
|
||||
return n - ((n - 1) // full_length) * full_length
|
||||
|
||||
IND_E = None
|
||||
for spec in example_lens_by_batch:
|
||||
|
||||
# get the (maybe partial) example seen in this cont batch
|
||||
dt2, X2, B2, C2 = get_continuous_batch(spec)
|
||||
|
||||
# get the metadata
|
||||
cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
|
||||
seq_idx = torch.zeros(cu_seqlens[-1],
|
||||
dtype=torch.int32,
|
||||
device=cu_seqlens.device)
|
||||
for i, (srt, end) in enumerate(zip(
|
||||
cu_seqlens,
|
||||
cu_seqlens[1:],
|
||||
)):
|
||||
seq_idx[srt:end] = i
|
||||
|
||||
# for cont batch
|
||||
if IND_E is None:
|
||||
IND_S = [0 for _ in range(len(spec))]
|
||||
else:
|
||||
IND_S = [x % full_length for x in IND_E]
|
||||
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
|
||||
|
||||
yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
|
||||
cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
|
||||
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
|
||||
@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
|
||||
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
||||
itype):
|
||||
|
||||
# this tests the kernels on a single example (no batching)
|
||||
|
||||
# set seed
|
||||
batch_size = 1 # batch_size
|
||||
# ssd_minimal_discrete requires chunk_size divide seqlen
|
||||
# - this is only required for generating the reference seqs,
|
||||
# it is not an operational limitation.
|
||||
seqlen, chunk_size = seq_len_chunk_size
|
||||
|
||||
A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
|
||||
d_head, itype)
|
||||
|
||||
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
|
||||
B, C, chunk_size)
|
||||
|
||||
Y, final_state = mamba_chunk_scan_combined(X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
return_final_states=True)
|
||||
|
||||
# just test the last in sequence
|
||||
torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)
|
||||
|
||||
# just test the last head
|
||||
# NOTE, in the kernel we always cast states to fp32
|
||||
torch.allclose(final_state[:, -1],
|
||||
final_state_min[:, -1].to(torch.float32),
|
||||
atol=1e-3,
|
||||
rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
|
||||
@pytest.mark.parametrize("n_heads", [4, 8, 13])
|
||||
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
|
||||
@pytest.mark.parametrize(
|
||||
"seq_len_chunk_size_cases",
|
||||
[
|
||||
|
||||
# small-ish chunk_size (8)
|
||||
(64, 8, 2, [(64, 32), (64, 32)]),
|
||||
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
|
||||
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
|
||||
(64, 8, 2, [(4, 4), (4, 4), (4, 4),
|
||||
(4, 4)]), # chunk_size larger than cont batches
|
||||
(64, 8, 5, [
|
||||
(64, 32, 16, 8, 8),
|
||||
(8, 16, 32, 16, 8),
|
||||
(8, 8, 16, 32, 16),
|
||||
]), # mode examples with varied lengths
|
||||
|
||||
# odd chunk_size
|
||||
(64, 29, 2, [(11, 4), (13, 23), (19, 22),
|
||||
(21, 15)]), # irregular sizes
|
||||
|
||||
# large-ish chunk_size (256)
|
||||
(64, 256, 1, [(5, ), (1, ), (1, ),
|
||||
(1, )]), # irregular sizes with small sequences
|
||||
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
|
||||
(1, 2)]), # irregular sizes with small sequences
|
||||
])
|
||||
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
itype):
|
||||
|
||||
# this test with multiple examples in a continuous batch
|
||||
# (i.e. chunked prefill)
|
||||
|
||||
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
|
||||
|
||||
# hold state during the cutting process so we know if an
|
||||
# example has been exhausted and needs to cycle
|
||||
last_taken: dict = {} # map: eg -> pointer to last taken sample
|
||||
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
|
||||
|
||||
states = None
|
||||
for Y_min, cu_seqlens, seq_idx, (A, dt, X, B,
|
||||
C) in generate_continous_batched_examples(
|
||||
cases, num_examples, seqlen,
|
||||
last_taken, exhausted, n_heads,
|
||||
d_head, itype):
|
||||
|
||||
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
|
||||
seq_idx, chunk_size)
|
||||
|
||||
Y, new_states = mamba_chunk_scan_combined(
|
||||
X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=cu_seqlens,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
return_varlen_states=True,
|
||||
initial_states=states,
|
||||
)
|
||||
|
||||
# just test the last in sequence
|
||||
for i in range(num_examples):
|
||||
|
||||
# just test one dim and dstate
|
||||
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
|
||||
Y_min_eg = Y_min[i][:, 0, 0]
|
||||
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
|
||||
|
||||
# update states
|
||||
states = new_states
|
||||
for i, clear in exhausted.items():
|
||||
if clear:
|
||||
states[i].fill_(0.)
|
||||
exhausted[i] = False
|
||||
Reference in New Issue
Block a user