[Bugfix] Use null block (0) for padded block table entries (#35431)
Signed-off-by: SandishKumarHN <sandish@fb.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -21,7 +21,7 @@ struct SSMParamsBase {
|
|||||||
int dim_ngroups_ratio;
|
int dim_ngroups_ratio;
|
||||||
bool is_variable_B;
|
bool is_variable_B;
|
||||||
bool is_variable_C;
|
bool is_variable_C;
|
||||||
int64_t pad_slot_id;
|
int64_t null_block_id;
|
||||||
|
|
||||||
bool delta_softplus;
|
bool delta_softplus;
|
||||||
bool cache_enabled;
|
bool cache_enabled;
|
||||||
|
|||||||
@@ -118,9 +118,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
|
|
||||||
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||||
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
int cache_index;
|
||||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
if (cache_indices == nullptr) {
|
||||||
if (cache_index == params.pad_slot_id){
|
cache_index = batch_id;
|
||||||
|
} else if (params.cache_enabled) {
|
||||||
|
const int* initial_state_idx = reinterpret_cast<const int*>(params.initial_state_idx_ptr);
|
||||||
|
cache_index = cache_indices[batch_id * params.cache_indices_stride + initial_state_idx[batch_id]];
|
||||||
|
} else {
|
||||||
|
cache_index = cache_indices[batch_id];
|
||||||
|
}
|
||||||
|
// Skip batch entries whose cache index maps to the null block (padding).
|
||||||
|
if (cache_indices != nullptr && cache_index == params.null_block_id){
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
|
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
|
||||||
@@ -527,7 +535,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
|
|||||||
const std::optional<at::Tensor>& cache_indices,
|
const std::optional<at::Tensor>& cache_indices,
|
||||||
const std::optional<at::Tensor>& has_initial_state,
|
const std::optional<at::Tensor>& has_initial_state,
|
||||||
bool varlen,
|
bool varlen,
|
||||||
int64_t pad_slot_id,
|
int64_t null_block_id,
|
||||||
int64_t block_size,
|
int64_t block_size,
|
||||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||||
@@ -544,7 +552,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
|
|||||||
params.dstate = dstate;
|
params.dstate = dstate;
|
||||||
params.n_groups = n_groups;
|
params.n_groups = n_groups;
|
||||||
params.dim_ngroups_ratio = dim / n_groups;
|
params.dim_ngroups_ratio = dim / n_groups;
|
||||||
params.pad_slot_id = pad_slot_id;
|
params.null_block_id = null_block_id;
|
||||||
|
|
||||||
params.delta_softplus = delta_softplus;
|
params.delta_softplus = delta_softplus;
|
||||||
|
|
||||||
@@ -658,7 +666,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
const torch::Tensor &ssm_states,
|
const torch::Tensor &ssm_states,
|
||||||
// used to identify padding entries if cache_indices provided
|
// used to identify padding entries if cache_indices provided
|
||||||
// in case of padding, the kernel will return early
|
// in case of padding, the kernel will return early
|
||||||
int64_t pad_slot_id,
|
int64_t null_block_id,
|
||||||
int64_t block_size,
|
int64_t block_size,
|
||||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||||
@@ -805,7 +813,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
cache_indices,
|
cache_indices,
|
||||||
has_initial_state,
|
has_initial_state,
|
||||||
varlen,
|
varlen,
|
||||||
pad_slot_id,
|
null_block_id,
|
||||||
block_size,
|
block_size,
|
||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
|
|||||||
@@ -298,7 +298,7 @@ void selective_scan_fwd(
|
|||||||
const std::optional<torch::Tensor>& query_start_loc,
|
const std::optional<torch::Tensor>& query_start_loc,
|
||||||
const std::optional<torch::Tensor>& cache_indices,
|
const std::optional<torch::Tensor>& cache_indices,
|
||||||
const std::optional<torch::Tensor>& has_initial_state,
|
const std::optional<torch::Tensor>& has_initial_state,
|
||||||
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
|
const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size,
|
||||||
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
||||||
const std::optional<torch::Tensor>& initial_state_idx,
|
const std::optional<torch::Tensor>& initial_state_idx,
|
||||||
|
|||||||
@@ -556,7 +556,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"Tensor? cache_indices,"
|
"Tensor? cache_indices,"
|
||||||
"Tensor? has_initial_state,"
|
"Tensor? has_initial_state,"
|
||||||
"Tensor! ssm_states,"
|
"Tensor! ssm_states,"
|
||||||
"int pad_slot_id,"
|
"int null_block_id,"
|
||||||
"int block_size,"
|
"int block_size,"
|
||||||
"Tensor? block_idx_first_scheduled_token,"
|
"Tensor? block_idx_first_scheduled_token,"
|
||||||
"Tensor? block_idx_last_scheduled_token,"
|
"Tensor? block_idx_last_scheduled_token,"
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
|||||||
causal_conv1d_update,
|
causal_conv1d_update,
|
||||||
)
|
)
|
||||||
from vllm.utils.torch_utils import set_random_seed
|
from vllm.utils.torch_utils import set_random_seed
|
||||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
|
||||||
|
|
||||||
|
|
||||||
def causal_conv1d_ref(
|
def causal_conv1d_ref(
|
||||||
@@ -122,7 +122,7 @@ def causal_conv1d_opcheck_fn(
|
|||||||
has_initial_state: torch.Tensor | None = None,
|
has_initial_state: torch.Tensor | None = None,
|
||||||
conv_states: torch.Tensor | None = None,
|
conv_states: torch.Tensor | None = None,
|
||||||
activation: str | None = "silu",
|
activation: str | None = "silu",
|
||||||
pad_slot_id: int = PAD_SLOT_ID,
|
null_block_id: int = NULL_BLOCK_ID,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
x: (batch, dim, seqlen)
|
x: (batch, dim, seqlen)
|
||||||
@@ -158,15 +158,16 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
|
|||||||
batch = 2
|
batch = 2
|
||||||
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
|
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
|
||||||
x_ref = x.clone()
|
x_ref = x.clone()
|
||||||
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
|
# +1 entry to reserve index 0 as null block
|
||||||
|
conv_state = torch.randn(batch + 1, dim, width - 1, device=device, dtype=itype)
|
||||||
|
|
||||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||||
conv_state_ref = conv_state.detach().clone()
|
# Start indices from 1, skipping null block at index 0
|
||||||
|
conv_state_indices = torch.arange(1, batch + 1, dtype=torch.int32, device=device)
|
||||||
|
conv_state_ref = conv_state[conv_state_indices].detach().clone()
|
||||||
activation = None if not silu_activation else "silu"
|
activation = None if not silu_activation else "silu"
|
||||||
|
|
||||||
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
|
|
||||||
|
|
||||||
out = causal_conv1d_update(
|
out = causal_conv1d_update(
|
||||||
x,
|
x,
|
||||||
conv_state,
|
conv_state,
|
||||||
@@ -179,7 +180,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
|
|||||||
x_ref, conv_state_ref, weight, bias, activation=activation
|
x_ref, conv_state_ref, weight, bias, activation=activation
|
||||||
)
|
)
|
||||||
|
|
||||||
assert torch.equal(conv_state, conv_state_ref)
|
assert torch.equal(conv_state[conv_state_indices], conv_state_ref)
|
||||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
@@ -215,7 +216,8 @@ def test_causal_conv1d_update_with_batch_gather(
|
|||||||
|
|
||||||
x_ref = x.clone()
|
x_ref = x.clone()
|
||||||
|
|
||||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
# +1 to exclude index 0 (null block)
|
||||||
|
conv_state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
|
||||||
dtype=torch.int32, device=device
|
dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||||
@@ -223,7 +225,9 @@ def test_causal_conv1d_update_with_batch_gather(
|
|||||||
padded_state_indices = torch.concat(
|
padded_state_indices = torch.concat(
|
||||||
[
|
[
|
||||||
conv_state_indices,
|
conv_state_indices,
|
||||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
torch.as_tensor(
|
||||||
|
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||||
|
),
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
@@ -248,7 +252,6 @@ def test_causal_conv1d_update_with_batch_gather(
|
|||||||
bias,
|
bias,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
conv_state_indices=padded_state_indices,
|
conv_state_indices=padded_state_indices,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
|
||||||
)
|
)
|
||||||
out_ref = causal_conv1d_update_ref(
|
out_ref = causal_conv1d_update_ref(
|
||||||
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
|
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
|
||||||
@@ -317,13 +320,19 @@ def test_causal_conv1d_varlen(
|
|||||||
has_initial_states = torch.randint(
|
has_initial_states = torch.randint(
|
||||||
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
|
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
|
||||||
)
|
)
|
||||||
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
|
# +1 to exclude index 0 (null block)
|
||||||
|
state_indices = (
|
||||||
|
torch.randperm(total_entries - 1, dtype=torch.int32, device=x.device)[
|
||||||
:batch_size
|
:batch_size
|
||||||
]
|
]
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
padded_state_indices = torch.concat(
|
padded_state_indices = torch.concat(
|
||||||
[
|
[
|
||||||
state_indices,
|
state_indices,
|
||||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
torch.as_tensor(
|
||||||
|
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||||
|
),
|
||||||
],
|
],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
@@ -336,7 +345,6 @@ def test_causal_conv1d_varlen(
|
|||||||
cache_indices=padded_state_indices,
|
cache_indices=padded_state_indices,
|
||||||
has_initial_state=has_initial_states,
|
has_initial_state=has_initial_states,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
out_ref = []
|
out_ref = []
|
||||||
@@ -345,7 +353,7 @@ def test_causal_conv1d_varlen(
|
|||||||
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
|
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
|
||||||
for i in range(len(seqlens[0])):
|
for i in range(len(seqlens[0])):
|
||||||
x_s = [v[i].unsqueeze(0) for v in splits][0]
|
x_s = [v[i].unsqueeze(0) for v in splits][0]
|
||||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
if padded_state_indices[i] == NULL_BLOCK_ID:
|
||||||
continue
|
continue
|
||||||
out_ref_b.append(
|
out_ref_b.append(
|
||||||
causal_conv1d_ref(
|
causal_conv1d_ref(
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import set_random_seed
|
from vllm.utils.torch_utils import set_random_seed
|
||||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
|
||||||
|
|
||||||
|
|
||||||
def selective_state_update_ref(
|
def selective_state_update_ref(
|
||||||
@@ -179,7 +179,7 @@ def selective_scan_opcheck_fn(
|
|||||||
cache_indices=None,
|
cache_indices=None,
|
||||||
has_initial_state=None,
|
has_initial_state=None,
|
||||||
ssm_states=None,
|
ssm_states=None,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
null_block_id=NULL_BLOCK_ID,
|
||||||
block_size=2048,
|
block_size=2048,
|
||||||
block_idx_first_scheduled_token=None,
|
block_idx_first_scheduled_token=None,
|
||||||
block_idx_last_scheduled_token=None,
|
block_idx_last_scheduled_token=None,
|
||||||
@@ -229,7 +229,7 @@ def selective_scan_opcheck_fn(
|
|||||||
cache_indices,
|
cache_indices,
|
||||||
has_initial_state,
|
has_initial_state,
|
||||||
ssm_states,
|
ssm_states,
|
||||||
pad_slot_id,
|
null_block_id,
|
||||||
block_size,
|
block_size,
|
||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
@@ -351,7 +351,6 @@ def test_selective_scan(
|
|||||||
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
|
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
|
||||||
if c > 0
|
if c > 0
|
||||||
else None,
|
else None,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
|
||||||
block_size=2048,
|
block_size=2048,
|
||||||
block_idx_first_scheduled_token=None,
|
block_idx_first_scheduled_token=None,
|
||||||
block_idx_last_scheduled_token=None,
|
block_idx_last_scheduled_token=None,
|
||||||
@@ -652,15 +651,21 @@ def test_selective_scan_varlen(
|
|||||||
prev_state_shape, device=u.device, dtype=itype, requires_grad=False
|
prev_state_shape, device=u.device, dtype=itype, requires_grad=False
|
||||||
)
|
)
|
||||||
prev_state_ref = prev_state.clone()
|
prev_state_ref = prev_state.clone()
|
||||||
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[
|
# +1 to exclude index 0 (null block)
|
||||||
|
state_indices = (
|
||||||
|
torch.randperm(total_entries - 1, dtype=torch.int32, device=u.device)[
|
||||||
:batch_size
|
:batch_size
|
||||||
]
|
]
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||||
unused_states_bool[state_indices] = False
|
unused_states_bool[state_indices] = False
|
||||||
padded_state_indices = torch.concat(
|
padded_state_indices = torch.concat(
|
||||||
[
|
[
|
||||||
state_indices,
|
state_indices,
|
||||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
torch.as_tensor(
|
||||||
|
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||||
|
),
|
||||||
],
|
],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
@@ -690,7 +695,7 @@ def test_selective_scan_varlen(
|
|||||||
]
|
]
|
||||||
for i in range(len(seqlens[0])):
|
for i in range(len(seqlens[0])):
|
||||||
u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
|
u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
|
||||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
if padded_state_indices[i] == NULL_BLOCK_ID:
|
||||||
continue
|
continue
|
||||||
out_ref_s, _ = selective_scan_ref(
|
out_ref_s, _ = selective_scan_ref(
|
||||||
u_s,
|
u_s,
|
||||||
@@ -758,7 +763,8 @@ def test_selective_state_update_with_batch_indices(
|
|||||||
padded_batch_size = batch_size + padding
|
padded_batch_size = batch_size + padding
|
||||||
total_entries = 10 * batch_size
|
total_entries = 10 * batch_size
|
||||||
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
|
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
|
||||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
# +1 to exclude index 0 (null block)
|
||||||
|
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
|
||||||
dtype=torch.int32, device=device
|
dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||||
@@ -766,7 +772,9 @@ def test_selective_state_update_with_batch_indices(
|
|||||||
padded_state_indices = torch.concat(
|
padded_state_indices = torch.concat(
|
||||||
[
|
[
|
||||||
state_indices,
|
state_indices,
|
||||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
torch.as_tensor(
|
||||||
|
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||||
|
),
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
@@ -793,7 +801,6 @@ def test_selective_state_update_with_batch_indices(
|
|||||||
dt_bias=dt_bias,
|
dt_bias=dt_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
state_batch_indices=padded_state_indices,
|
state_batch_indices=padded_state_indices,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
|
||||||
out=out,
|
out=out,
|
||||||
)
|
)
|
||||||
out_ref = selective_state_update_ref(
|
out_ref = selective_state_update_ref(
|
||||||
@@ -849,7 +856,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
|||||||
state = torch.randn(
|
state = torch.randn(
|
||||||
total_entries, nheads, headdim, dstate, dtype=itype, device=device
|
total_entries, nheads, headdim, dstate, dtype=itype, device=device
|
||||||
)
|
)
|
||||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
# +1 to exclude index 0 (null block)
|
||||||
|
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
|
||||||
dtype=torch.int32, device=device
|
dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -887,7 +895,6 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
|||||||
dt_bias=dt_bias,
|
dt_bias=dt_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
state_batch_indices=state_indices,
|
state_batch_indices=state_indices,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
|
||||||
out=out,
|
out=out,
|
||||||
)
|
)
|
||||||
out_ref = selective_state_update_ref(
|
out_ref = selective_state_update_ref(
|
||||||
@@ -935,17 +942,18 @@ def test_selective_state_update_with_num_accepted_tokens(
|
|||||||
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
|
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
|
||||||
|
|
||||||
state_batch_indices = torch.full(
|
state_batch_indices = torch.full(
|
||||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
# Start from 1 to exclude null block at index 0
|
||||||
initial_state_slots = torch.randint(
|
initial_state_slots = torch.randint(
|
||||||
0, 15, (batch_size,), device=device, dtype=torch.int32
|
1, 15, (batch_size,), device=device, dtype=torch.int32
|
||||||
)
|
)
|
||||||
for seq_idx in range(batch_size):
|
for seq_idx in range(batch_size):
|
||||||
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
|
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
|
||||||
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
|
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
|
||||||
|
|
||||||
dst_state_batch_indices = torch.full(
|
dst_state_batch_indices = torch.full(
|
||||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
slot_offset = 15
|
slot_offset = 15
|
||||||
dst_slots_map = {}
|
dst_slots_map = {}
|
||||||
@@ -1013,7 +1021,6 @@ def test_selective_state_update_with_num_accepted_tokens(
|
|||||||
state_batch_indices=state_batch_indices,
|
state_batch_indices=state_batch_indices,
|
||||||
dst_state_batch_indices=dst_state_batch_indices,
|
dst_state_batch_indices=dst_state_batch_indices,
|
||||||
num_accepted_tokens=num_accepted_tokens,
|
num_accepted_tokens=num_accepted_tokens,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||||
@@ -1061,18 +1068,19 @@ def test_selective_state_update_varlen_with_num_accepted(
|
|||||||
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
|
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
|
||||||
|
|
||||||
state_batch_indices = torch.full(
|
state_batch_indices = torch.full(
|
||||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Start from 1 to exclude null block at index 0
|
||||||
initial_state_slots = torch.randint(
|
initial_state_slots = torch.randint(
|
||||||
0, 15, (batch_size,), device=device, dtype=torch.int32
|
1, 15, (batch_size,), device=device, dtype=torch.int32
|
||||||
)
|
)
|
||||||
for seq_idx in range(batch_size):
|
for seq_idx in range(batch_size):
|
||||||
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
|
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
|
||||||
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
|
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
|
||||||
|
|
||||||
dst_state_batch_indices = torch.full(
|
dst_state_batch_indices = torch.full(
|
||||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
slot_offset = 15
|
slot_offset = 15
|
||||||
@@ -1138,7 +1146,6 @@ def test_selective_state_update_varlen_with_num_accepted(
|
|||||||
state_batch_indices=state_batch_indices,
|
state_batch_indices=state_batch_indices,
|
||||||
dst_state_batch_indices=dst_state_batch_indices,
|
dst_state_batch_indices=dst_state_batch_indices,
|
||||||
num_accepted_tokens=num_accepted_tokens,
|
num_accepted_tokens=num_accepted_tokens,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for seq_idx in range(batch_size):
|
for seq_idx in range(batch_size):
|
||||||
|
|||||||
@@ -2062,7 +2062,7 @@ def selective_scan_fwd(
|
|||||||
cache_indices: torch.Tensor | None,
|
cache_indices: torch.Tensor | None,
|
||||||
has_initial_state: torch.Tensor | None,
|
has_initial_state: torch.Tensor | None,
|
||||||
ssm_states: torch.Tensor,
|
ssm_states: torch.Tensor,
|
||||||
pad_slot_id: int,
|
null_block_id: int,
|
||||||
block_size: int = 1024,
|
block_size: int = 1024,
|
||||||
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
||||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||||
@@ -2084,7 +2084,7 @@ def selective_scan_fwd(
|
|||||||
cache_indices,
|
cache_indices,
|
||||||
has_initial_state,
|
has_initial_state,
|
||||||
ssm_states,
|
ssm_states,
|
||||||
pad_slot_id,
|
null_block_id,
|
||||||
block_size,
|
block_size,
|
||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import torch
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -602,6 +603,7 @@ def _linear_attn_decode_kernel(
|
|||||||
cache_h_stride,
|
cache_h_stride,
|
||||||
cache_d0_stride,
|
cache_d0_stride,
|
||||||
cache_d1_stride,
|
cache_d1_stride,
|
||||||
|
pad_slot_id: tl.constexpr,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -616,8 +618,8 @@ def _linear_attn_decode_kernel(
|
|||||||
# Load slot index for the current batch
|
# Load slot index for the current batch
|
||||||
slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
|
slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
|
||||||
|
|
||||||
# Skip if slot_id is -1 (padding)
|
# Skip if slot_id is PAD_SLOT_ID (padding)
|
||||||
if slot_id == -1:
|
if slot_id == pad_slot_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
batch_id = pid_b
|
batch_id = pid_b
|
||||||
@@ -727,6 +729,7 @@ def linear_decode_forward_triton(
|
|||||||
cache_h_stride,
|
cache_h_stride,
|
||||||
cache_d0_stride,
|
cache_d0_stride,
|
||||||
cache_d1_stride,
|
cache_d1_stride,
|
||||||
|
pad_slot_id=PAD_SLOT_ID,
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID, PAD_SLOT_ID
|
||||||
|
|
||||||
|
|
||||||
@triton.jit()
|
@triton.jit()
|
||||||
@@ -49,12 +49,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|||||||
stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M
|
stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M
|
||||||
# others
|
# others
|
||||||
pad_slot_id: tl.constexpr,
|
pad_slot_id: tl.constexpr,
|
||||||
|
null_block_id: tl.constexpr,
|
||||||
# Meta-parameters
|
# Meta-parameters
|
||||||
HAS_BIAS: tl.constexpr,
|
HAS_BIAS: tl.constexpr,
|
||||||
KERNEL_WIDTH: tl.constexpr,
|
KERNEL_WIDTH: tl.constexpr,
|
||||||
SILU_ACTIVATION: tl.constexpr,
|
SILU_ACTIVATION: tl.constexpr,
|
||||||
IS_APC_ENABLED: tl.constexpr,
|
IS_APC_ENABLED: tl.constexpr,
|
||||||
USE_PAD_SLOT: tl.constexpr,
|
HAS_NULL_BLOCK: tl.constexpr,
|
||||||
NP2_STATELEN: tl.constexpr,
|
NP2_STATELEN: tl.constexpr,
|
||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
@@ -133,9 +134,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|||||||
conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index
|
conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index
|
||||||
).to(tl.int64)
|
).to(tl.int64)
|
||||||
|
|
||||||
if USE_PAD_SLOT: # noqa
|
if HAS_NULL_BLOCK: # noqa
|
||||||
if conv_states_input_coord == pad_slot_id:
|
if conv_states_input_coord == null_block_id:
|
||||||
# not processing as this is not the actual sequence
|
# not processing as this is a null block (padding)
|
||||||
return
|
return
|
||||||
conv_states_base = (
|
conv_states_base = (
|
||||||
conv_states_ptr
|
conv_states_ptr
|
||||||
@@ -475,6 +476,7 @@ def causal_conv1d_fn(
|
|||||||
has_initial_state: torch.Tensor | None = None,
|
has_initial_state: torch.Tensor | None = None,
|
||||||
activation: str | None = "silu",
|
activation: str | None = "silu",
|
||||||
pad_slot_id: int = PAD_SLOT_ID,
|
pad_slot_id: int = PAD_SLOT_ID,
|
||||||
|
null_block_id: int = NULL_BLOCK_ID,
|
||||||
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
||||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||||
initial_state_idx: torch.Tensor | None = None,
|
initial_state_idx: torch.Tensor | None = None,
|
||||||
@@ -730,12 +732,13 @@ def causal_conv1d_fn(
|
|||||||
block_size_to_align // BLOCK_M,
|
block_size_to_align // BLOCK_M,
|
||||||
# others
|
# others
|
||||||
pad_slot_id,
|
pad_slot_id,
|
||||||
|
null_block_id,
|
||||||
# META
|
# META
|
||||||
HAS_BIAS=bias is not None,
|
HAS_BIAS=bias is not None,
|
||||||
KERNEL_WIDTH=width,
|
KERNEL_WIDTH=width,
|
||||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||||
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
|
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
|
||||||
USE_PAD_SLOT=pad_slot_id is not None,
|
HAS_NULL_BLOCK=null_block_id is not None,
|
||||||
NP2_STATELEN=np2_statelen,
|
NP2_STATELEN=np2_statelen,
|
||||||
# launch_cooperative_grid=True
|
# launch_cooperative_grid=True
|
||||||
BLOCK_M=BLOCK_M,
|
BLOCK_M=BLOCK_M,
|
||||||
@@ -778,7 +781,7 @@ def _causal_conv1d_update_kernel(
|
|||||||
stride_o_dim: tl.constexpr,
|
stride_o_dim: tl.constexpr,
|
||||||
stride_o_token: tl.constexpr,
|
stride_o_token: tl.constexpr,
|
||||||
# others
|
# others
|
||||||
pad_slot_id: tl.constexpr,
|
null_block_id: tl.constexpr,
|
||||||
# Meta-parameters
|
# Meta-parameters
|
||||||
HAS_BIAS: tl.constexpr,
|
HAS_BIAS: tl.constexpr,
|
||||||
KERNEL_WIDTH: tl.constexpr,
|
KERNEL_WIDTH: tl.constexpr,
|
||||||
@@ -787,7 +790,7 @@ def _causal_conv1d_update_kernel(
|
|||||||
IS_APC_ENABLED: tl.constexpr,
|
IS_APC_ENABLED: tl.constexpr,
|
||||||
IS_SPEC_DECODING: tl.constexpr,
|
IS_SPEC_DECODING: tl.constexpr,
|
||||||
NP2_STATELEN: tl.constexpr,
|
NP2_STATELEN: tl.constexpr,
|
||||||
USE_PAD_SLOT: tl.constexpr,
|
HAS_NULL_BLOCK: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
):
|
):
|
||||||
# ruff: noqa: E501
|
# ruff: noqa: E501
|
||||||
@@ -811,8 +814,8 @@ def _causal_conv1d_update_kernel(
|
|||||||
conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init
|
conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init
|
||||||
).to(tl.int64)
|
).to(tl.int64)
|
||||||
|
|
||||||
if USE_PAD_SLOT: # noqa
|
if HAS_NULL_BLOCK: # noqa
|
||||||
if conv_states_input_coord == pad_slot_id:
|
if conv_states_input_coord == null_block_id:
|
||||||
# not processing as this is not the actual sequence
|
# not processing as this is not the actual sequence
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1076,7 +1079,7 @@ def causal_conv1d_update(
|
|||||||
num_accepted_tokens: torch.Tensor | None = None,
|
num_accepted_tokens: torch.Tensor | None = None,
|
||||||
query_start_loc: torch.Tensor | None = None,
|
query_start_loc: torch.Tensor | None = None,
|
||||||
max_query_len: int = -1,
|
max_query_len: int = -1,
|
||||||
pad_slot_id: int = PAD_SLOT_ID,
|
null_block_id: int = NULL_BLOCK_ID,
|
||||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||||
initial_state_idx: torch.Tensor | None = None,
|
initial_state_idx: torch.Tensor | None = None,
|
||||||
validate_data=False,
|
validate_data=False,
|
||||||
@@ -1111,16 +1114,16 @@ def causal_conv1d_update(
|
|||||||
max_query_len: int
|
max_query_len: int
|
||||||
If query_start_loc is not None, this indicates the maximum query
|
If query_start_loc is not None, this indicates the maximum query
|
||||||
length in the batch.
|
length in the batch.
|
||||||
pad_slot_id: int
|
null_block_id: int
|
||||||
if conv_state_indices is passed, lets the kernel identify padded
|
Block ID used to identify padded entries in
|
||||||
entries that will not be processed,
|
conv_state_indices. Block 0 is the null block.
|
||||||
for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
for example: conv_state_indices = [null_block_id, 1, 20, null_block_id]
|
||||||
in this case, the kernel will not process entries at
|
in this case, the kernel will not process entries at
|
||||||
indices 0 and 3
|
indices 0 and 3
|
||||||
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
|
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
|
||||||
"""
|
"""
|
||||||
if validate_data:
|
if validate_data:
|
||||||
assert pad_slot_id is not None
|
assert null_block_id is not None
|
||||||
assert x.stride(1) == 1
|
assert x.stride(1) == 1
|
||||||
if isinstance(activation, bool):
|
if isinstance(activation, bool):
|
||||||
activation = "silu" if activation is True else None
|
activation = "silu" if activation is True else None
|
||||||
@@ -1225,7 +1228,7 @@ def causal_conv1d_update(
|
|||||||
stride_o_dim,
|
stride_o_dim,
|
||||||
stride_o_token,
|
stride_o_token,
|
||||||
# others
|
# others
|
||||||
pad_slot_id,
|
null_block_id,
|
||||||
# META
|
# META
|
||||||
HAS_BIAS=bias is not None,
|
HAS_BIAS=bias is not None,
|
||||||
KERNEL_WIDTH=width,
|
KERNEL_WIDTH=width,
|
||||||
@@ -1234,7 +1237,7 @@ def causal_conv1d_update(
|
|||||||
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
|
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
|
||||||
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
||||||
NP2_STATELEN=np2_statelen,
|
NP2_STATELEN=np2_statelen,
|
||||||
USE_PAD_SLOT=pad_slot_id is not None,
|
HAS_NULL_BLOCK=null_block_id is not None,
|
||||||
BLOCK_N=256,
|
BLOCK_N=256,
|
||||||
)
|
)
|
||||||
if unsqueeze:
|
if unsqueeze:
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from packaging import version
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
|
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
|
||||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
|
||||||
|
|
||||||
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
|
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
|
||||||
|
|
||||||
@@ -75,7 +75,7 @@ def _selective_scan_update_kernel(
|
|||||||
out_ptr,
|
out_ptr,
|
||||||
state_batch_indices_ptr,
|
state_batch_indices_ptr,
|
||||||
dst_state_batch_indices_ptr,
|
dst_state_batch_indices_ptr,
|
||||||
pad_slot_id,
|
null_block_id,
|
||||||
num_accepted_tokens_ptr,
|
num_accepted_tokens_ptr,
|
||||||
cu_seqlens_ptr,
|
cu_seqlens_ptr,
|
||||||
# Matrix dimensions
|
# Matrix dimensions
|
||||||
@@ -203,7 +203,7 @@ def _selective_scan_update_kernel(
|
|||||||
|
|
||||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||||
if HAS_STATE_BATCH_INDICES:
|
if HAS_STATE_BATCH_INDICES:
|
||||||
mask &= state_batch_idx != pad_slot_id
|
mask &= state_batch_idx != null_block_id
|
||||||
state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
|
state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
|
||||||
|
|
||||||
if HAS_DT_BIAS:
|
if HAS_DT_BIAS:
|
||||||
@@ -257,7 +257,7 @@ def _selective_scan_update_kernel(
|
|||||||
if IS_SPEC_DECODING:
|
if IS_SPEC_DECODING:
|
||||||
dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T
|
dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T
|
||||||
token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64)
|
token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64)
|
||||||
if token_dst_idx != pad_slot_id:
|
if token_dst_idx != null_block_id:
|
||||||
token_dst_ptrs = (
|
token_dst_ptrs = (
|
||||||
state_ptr_base
|
state_ptr_base
|
||||||
+ token_dst_idx * stride_state_batch
|
+ token_dst_idx * stride_state_batch
|
||||||
@@ -329,7 +329,7 @@ def selective_state_update(
|
|||||||
dt_softplus=False,
|
dt_softplus=False,
|
||||||
state_batch_indices=None,
|
state_batch_indices=None,
|
||||||
dst_state_batch_indices=None,
|
dst_state_batch_indices=None,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
null_block_id=NULL_BLOCK_ID,
|
||||||
out=None,
|
out=None,
|
||||||
num_accepted_tokens=None,
|
num_accepted_tokens=None,
|
||||||
cu_seqlens=None,
|
cu_seqlens=None,
|
||||||
@@ -348,12 +348,12 @@ def selective_state_update(
|
|||||||
D: (dim,) or (nheads, dim)
|
D: (dim,) or (nheads, dim)
|
||||||
z: (batch, dim) or (batch, nheads, dim)
|
z: (batch, dim) or (batch, nheads, dim)
|
||||||
dt_bias: (dim,) or (nheads, dim)
|
dt_bias: (dim,) or (nheads, dim)
|
||||||
pad_slot_id: int
|
null_block_id: int
|
||||||
if cache_indices is passed, lets the kernel identify padded
|
if state_batch_indices is passed, lets the kernel identify
|
||||||
entries that will not be processed,
|
padded entries that will not be processed,
|
||||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
for example: state_batch_indices = [null_block_id, 1, 20,
|
||||||
in this case, the kernel will not process entries at
|
null_block_id] in this case, the kernel will not process
|
||||||
indices 0 and 3
|
entries at indices 0 and 3
|
||||||
out: Preallocated ssm output tensor. Assume same shape as x.
|
out: Preallocated ssm output tensor. Assume same shape as x.
|
||||||
In-place updated.
|
In-place updated.
|
||||||
num_accepted_tokens: (batch,)
|
num_accepted_tokens: (batch,)
|
||||||
@@ -488,7 +488,7 @@ def selective_state_update(
|
|||||||
out,
|
out,
|
||||||
state_batch_indices,
|
state_batch_indices,
|
||||||
dst_state_batch_indices,
|
dst_state_batch_indices,
|
||||||
pad_slot_id,
|
null_block_id,
|
||||||
num_accepted_tokens,
|
num_accepted_tokens,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
N,
|
N,
|
||||||
@@ -550,7 +550,7 @@ def selective_scan_fn(
|
|||||||
query_start_loc=None,
|
query_start_loc=None,
|
||||||
cache_indices=None,
|
cache_indices=None,
|
||||||
has_initial_state=None,
|
has_initial_state=None,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
null_block_id=NULL_BLOCK_ID,
|
||||||
block_size=1024,
|
block_size=1024,
|
||||||
block_idx_first_scheduled_token=None,
|
block_idx_first_scheduled_token=None,
|
||||||
block_idx_last_scheduled_token=None,
|
block_idx_last_scheduled_token=None,
|
||||||
@@ -588,10 +588,10 @@ def selective_scan_fn(
|
|||||||
indicate if the ssm_state at the corresponding index should be
|
indicate if the ssm_state at the corresponding index should be
|
||||||
used as initial state. Not providing argument assumes
|
used as initial state. Not providing argument assumes
|
||||||
there's no initial state
|
there's no initial state
|
||||||
pad_slot_id: int
|
null_block_id: int
|
||||||
if cache_indices is passed, lets the kernel identify padding entries
|
if cache_indices is passed, lets the kernel identify padding entries
|
||||||
that will not be processed,
|
that will not be processed,
|
||||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
for example: cache_indices = [null_block_id, 1 ,20 ,null_block_id]
|
||||||
in this case, the kernel will not process entries at indices 0 and 3
|
in this case, the kernel will not process entries at indices 0 and 3
|
||||||
block_size: int
|
block_size: int
|
||||||
The block size to align the cached states to
|
The block size to align the cached states to
|
||||||
@@ -643,7 +643,7 @@ def selective_scan_fn(
|
|||||||
cache_indices,
|
cache_indices,
|
||||||
has_initial_state,
|
has_initial_state,
|
||||||
ssm_states,
|
ssm_states,
|
||||||
pad_slot_id,
|
null_block_id,
|
||||||
block_size,
|
block_size,
|
||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from vllm.v1.attention.backend import (
|
|||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
PAD_SLOT_ID,
|
NULL_BLOCK_ID,
|
||||||
compute_causal_conv1d_metadata,
|
compute_causal_conv1d_metadata,
|
||||||
mamba_get_block_table_tensor,
|
mamba_get_block_table_tensor,
|
||||||
split_decodes_and_prefills,
|
split_decodes_and_prefills,
|
||||||
@@ -341,7 +341,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
|||||||
spec_state_indices_tensor, non_blocking=True
|
spec_state_indices_tensor, non_blocking=True
|
||||||
)
|
)
|
||||||
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
|
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
|
||||||
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
|
spec_state_indices_tensor[num_spec_decodes:].fill_(NULL_BLOCK_ID)
|
||||||
|
|
||||||
self.spec_sequence_masks[:num_spec_decodes].copy_(
|
self.spec_sequence_masks[:num_spec_decodes].copy_(
|
||||||
spec_sequence_masks[:num_spec_decodes], non_blocking=True
|
spec_sequence_masks[:num_spec_decodes], non_blocking=True
|
||||||
@@ -387,7 +387,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
|||||||
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
|
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
|
||||||
:batch_size
|
:batch_size
|
||||||
]
|
]
|
||||||
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
|
non_spec_state_indices_tensor[num_decodes:].fill_(NULL_BLOCK_ID)
|
||||||
|
|
||||||
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
|
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
|
||||||
non_spec_query_start_loc, non_blocking=True
|
non_spec_query_start_loc, non_blocking=True
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from vllm.v1.attention.backend import (
|
|||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
PAD_SLOT_ID,
|
NULL_BLOCK_ID,
|
||||||
compute_causal_conv1d_metadata,
|
compute_causal_conv1d_metadata,
|
||||||
mamba_get_block_table_tensor,
|
mamba_get_block_table_tensor,
|
||||||
split_decodes_and_prefills,
|
split_decodes_and_prefills,
|
||||||
@@ -504,7 +504,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
|||||||
state_indices_tensor_d, non_blocking=True
|
state_indices_tensor_d, non_blocking=True
|
||||||
)
|
)
|
||||||
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
|
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
|
||||||
state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID
|
state_indices_tensor_d[metadata.num_decodes :] = NULL_BLOCK_ID
|
||||||
|
|
||||||
if self.use_spec_decode and num_accepted_tokens is not None:
|
if self.use_spec_decode and num_accepted_tokens is not None:
|
||||||
assert query_start_loc_d is not None
|
assert query_start_loc_d is not None
|
||||||
|
|||||||
@@ -366,12 +366,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
|||||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||||
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
|
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
|
||||||
|
|
||||||
# Padded CUDA graph requests have block_table entries of -1.
|
|
||||||
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
|
|
||||||
# This is safe because padded requests have seq_lens=0, so the
|
|
||||||
# kernel produces no meaningful output for those rows.
|
|
||||||
block_table.clamp_(min=0)
|
|
||||||
|
|
||||||
max_decode_len = int(decode_lens_cpu.max().item())
|
max_decode_len = int(decode_lens_cpu.max().item())
|
||||||
next_n = 1 + self.num_speculative_tokens
|
next_n = 1 + self.num_speculative_tokens
|
||||||
use_native = not self.use_flattening and max_decode_len == next_n
|
use_native = not self.use_flattening and max_decode_len == next_n
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ KVCacheLayoutType = Literal["NHD", "HND"]
|
|||||||
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
|
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
|
||||||
|
|
||||||
PAD_SLOT_ID = -1
|
PAD_SLOT_ID = -1
|
||||||
|
NULL_BLOCK_ID = 0
|
||||||
|
|
||||||
|
|
||||||
def is_valid_kv_cache_layout(value: str) -> bool:
|
def is_valid_kv_cache_layout(value: str) -> bool:
|
||||||
|
|||||||
@@ -122,6 +122,7 @@ from vllm.v1.attention.backend import (
|
|||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder
|
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
NULL_BLOCK_ID,
|
||||||
create_fast_prefill_custom_backend,
|
create_fast_prefill_custom_backend,
|
||||||
get_dcp_local_seq_lens,
|
get_dcp_local_seq_lens,
|
||||||
reorder_batch_to_split_decodes_and_prefills,
|
reorder_batch_to_split_decodes_and_prefills,
|
||||||
@@ -2135,9 +2136,9 @@ class GPUModelRunner(
|
|||||||
blk_table = self.input_batch.block_table[kv_cache_gid]
|
blk_table = self.input_batch.block_table[kv_cache_gid]
|
||||||
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
|
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
|
||||||
|
|
||||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
# Fill unused block table entries with NULL_BLOCK_ID (null block)
|
||||||
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
|
# for CUDAGraph padding. Block 0 is reserved for padding.
|
||||||
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
|
blk_table_tensor[num_reqs:num_reqs_padded].fill_(NULL_BLOCK_ID)
|
||||||
return blk_table_tensor
|
return blk_table_tensor
|
||||||
|
|
||||||
assert slot_mappings is not None
|
assert slot_mappings is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user