[Kernel][Model] Varlen prefill + Prefill chunking support for mamba kernels and Jamba model (#8533)

This commit is contained in:
Mor Zusman
2024-09-30 00:35:58 +03:00
committed by GitHub
parent 6c9ba48fde
commit f13a07b1f8
13 changed files with 1176 additions and 894 deletions

View File

@@ -3,7 +3,6 @@ from typing import Optional
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
@@ -57,43 +56,72 @@ def causal_conv1d_ref(
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_update_ref(x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None):
def causal_conv1d_update_ref(x,
conv_state,
weight,
bias=None,
activation=None,
cache_seqlens=None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
out: (batch, dim)
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
batch, dim = x.shape
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
width = weight.shape[1]
assert conv_state.shape == (batch, dim, width)
state_len = conv_state.shape[-1]
assert conv_state.shape == (batch, dim, state_len)
assert weight.shape == (dim, width)
conv_state.copy_(torch.roll(conv_state, shifts=-1,
dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
out = torch.sum(conv_state * weight, dim=-1) # (B D)
if bias is not None:
out += bias
if cache_seqlens is None:
x_new = torch.cat([conv_state, x], dim=-1).to(
weight.dtype) # (batch, dim, state_len + seqlen)
conv_state.copy_(x_new[:, :, -state_len:])
else:
width_idx = torch.arange(
-(width - 1), 0, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(
-1, dim, -1)
x_new = torch.cat([conv_state.gather(2, width_idx), x],
dim=-1).to(weight.dtype)
copy_idx = torch.arange(
seqlen, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
copy_idx = torch.remainder(copy_idx,
state_len).unsqueeze(1).expand(-1, dim, -1)
conv_state.scatter_(2, copy_idx, x)
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
groups=dim)[:, :, -seqlen:]
if unsqueeze:
out = out.squeeze(-1)
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
def causal_conv1d_opcheck_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out=None,
cu_seq_len: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
"""
@@ -109,135 +137,93 @@ def causal_conv1d_opcheck_fn(
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(2) != 1 and x.stride(1) != 1:
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
if seq_idx is not None:
assert (initial_states is
None), "initial_states must be None if seq_idx is not None"
assert (not return_final_states
), "If seq_idx is not None, we don't return final_states_out"
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
if initial_states is not None and (initial_states.stride(2) != 1
and initial_states.stride(1) != 1):
initial_states = initial_states.contiguous()
if return_final_states:
assert (
x.stride(1) == 1
), "Only channel-last layout support returning final_states_out"
if final_states_out is not None:
assert (final_states_out.stride(2) == 1
or final_states_out.stride(1) == 1)
else:
batch, dim, seqlen = x.shape
width = weight.shape[1]
final_states_out = torch.empty(batch,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
else:
final_states_out = None
opcheck(torch.ops._C.causal_conv1d_fwd,
(x, weight, bias, seq_idx, initial_states, final_states_out,
activation in ["silu", "swish"]))
opcheck(torch.ops._C.causal_conv1d_fwd, (
x,
weight,
bias,
conv_states,
cu_seq_len,
cache_indices,
has_initial_state,
activation in ["silu", "swish"],
))
@pytest.mark.parametrize("return_final_states", [False, True])
@pytest.mark.parametrize("has_initial_states", [False, True])
@pytest.mark.parametrize("channel_last", [False, True])
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [128, 512, 4096])
@pytest.mark.parametrize('dim', [64, 4096 + 32])
@pytest.mark.parametrize('batch', [1, 2])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
itype, channel_last, has_initial_states,
return_final_states):
if not channel_last and (has_initial_states or return_final_states):
pytest.skip(
"Only channel_last support initial_states or return_final_states")
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
seed_everything(0)
if not channel_last:
x = torch.randn(batch,
4096 + dim + 64,
seqlen,
device=device,
dtype=itype)[:, 4096:4096 + dim, :]
else:
x = rearrange(
torch.randn(batch,
seqlen,
4096 + dim + 64,
device=device,
dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s")
x = torch.randn(batch, dim, seqlen, device=device,
dtype=itype).contiguous()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
if has_initial_states:
initial_states = torch.randn(batch,
width - 1,
dim,
device=device,
dtype=itype).transpose(1, 2)
else:
initial_states = None
x_ref = x.detach().clone()
weight_ref = weight.detach().clone()
bias_ref = bias.detach().clone() if bias is not None else None
initial_states_ref = initial_states.detach().clone(
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
initial_states_ref = initial_states.clone(
) if initial_states is not None else None
activation = None if not silu_activation else "silu"
out, final_states = causal_conv1d_fn(
x,
weight,
bias,
initial_states=initial_states,
return_final_states=return_final_states,
activation=activation)
out = causal_conv1d_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=torch.ones(batch,
dtype=torch.bool,
device=x.device))
out_ref, final_states_ref = causal_conv1d_ref(
x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=return_final_states,
return_final_states=True,
activation=activation)
causal_conv1d_opcheck_fn(x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=return_final_states,
activation=activation)
if return_final_states:
assert final_states is not None and final_states_ref is not None
assert torch.allclose(final_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
if return_final_states:
out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
causal_conv1d_opcheck_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=torch.ones(batch,
dtype=torch.bool,
device=x.device))
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("batch", [1, 2])
def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
@@ -246,8 +232,9 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
# set seed
seed_everything(0)
batch = 2
x = torch.randn(batch, dim, device=device, dtype=itype)
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
weight = torch.randn(dim,
width,
device=device,
@@ -273,9 +260,15 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(
torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation in ["silu", "swish"], None))
opcheck(torch.ops._C.causal_conv1d_update, (
x,
conv_state,
weight,
bias,
activation in ["silu", "swish"],
None,
None,
))
@pytest.mark.parametrize("itype",
@@ -292,16 +285,16 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.random.manual_seed(0)
# set )seed
seed_everything(0)
batch = 64
x = torch.randn(batch, dim, device=device, dtype=itype)
x = torch.randn(batch, dim, 1, device=device, dtype=itype)
total_entries = 10 * batch
conv_state = torch.randn(total_entries,
dim,
width,
width - 1,
device=device,
dtype=itype)
conv_state_indices = torch.randperm(total_entries)[:batch].to(
@@ -332,3 +325,100 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(torch.ops._C.causal_conv1d_update, (
x,
conv_state,
weight,
bias,
activation in ["silu", "swish"],
None,
conv_state_indices,
))
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize('seqlen',
[8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
seed_everything(0)
batch = 1
seqlens = []
nsplits = 3
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0)
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device,
dtype=itype)[:, 4096:4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu"
final_states = torch.randn(nsplits + 1,
dim,
width - 1,
device=x.device,
dtype=x.dtype)
final_states_ref = final_states.clone()
has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=x.device)
cache_indices = torch.randperm(cumsum.shape[0] - 1,
dtype=torch.int32,
device=x.device)
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
cache_indices, has_initial_states, final_states,
activation)
out_ref = []
out_ref_b = []
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0]
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight_ref,
bias_ref,
activation=activation,
return_final_states=True,
final_states_out=final_states_ref[cache_indices[i]].unsqueeze(
0),
initial_states=final_states_ref[cache_indices[i]].unsqueeze(0)
if has_initial_states[i] else None))
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref = torch.cat(out_ref, dim=0)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print("Output state max diff"
f":{(final_states - final_states_ref).abs().max()}")
print("Output state mean diff"
f":{(final_states - final_states_ref).abs().mean()}")
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
cache_indices, has_initial_states, final_states,
activation)