[Bugfix] Use null block (0) for padded block table entries (#35431)
Signed-off-by: SandishKumarHN <sandish@fb.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -21,7 +21,7 @@ struct SSMParamsBase {
|
||||
int dim_ngroups_ratio;
|
||||
bool is_variable_B;
|
||||
bool is_variable_C;
|
||||
int64_t pad_slot_id;
|
||||
int64_t null_block_id;
|
||||
|
||||
bool delta_softplus;
|
||||
bool cache_enabled;
|
||||
|
||||
@@ -118,9 +118,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
|
||||
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
||||
if (cache_index == params.pad_slot_id){
|
||||
int cache_index;
|
||||
if (cache_indices == nullptr) {
|
||||
cache_index = batch_id;
|
||||
} else if (params.cache_enabled) {
|
||||
const int* initial_state_idx = reinterpret_cast<const int*>(params.initial_state_idx_ptr);
|
||||
cache_index = cache_indices[batch_id * params.cache_indices_stride + initial_state_idx[batch_id]];
|
||||
} else {
|
||||
cache_index = cache_indices[batch_id];
|
||||
}
|
||||
// Skip batch entries whose cache index maps to the null block (padding).
|
||||
if (cache_indices != nullptr && cache_index == params.null_block_id){
|
||||
return;
|
||||
}
|
||||
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
|
||||
@@ -527,7 +535,7 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool varlen,
|
||||
int64_t pad_slot_id,
|
||||
int64_t null_block_id,
|
||||
int64_t block_size,
|
||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||
@@ -544,7 +552,7 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
params.dstate = dstate;
|
||||
params.n_groups = n_groups;
|
||||
params.dim_ngroups_ratio = dim / n_groups;
|
||||
params.pad_slot_id = pad_slot_id;
|
||||
params.null_block_id = null_block_id;
|
||||
|
||||
params.delta_softplus = delta_softplus;
|
||||
|
||||
@@ -658,7 +666,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
const torch::Tensor &ssm_states,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id,
|
||||
int64_t null_block_id,
|
||||
int64_t block_size,
|
||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||
@@ -805,7 +813,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
varlen,
|
||||
pad_slot_id,
|
||||
null_block_id,
|
||||
block_size,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
|
||||
@@ -298,7 +298,7 @@ void selective_scan_fwd(
|
||||
const std::optional<torch::Tensor>& query_start_loc,
|
||||
const std::optional<torch::Tensor>& cache_indices,
|
||||
const std::optional<torch::Tensor>& has_initial_state,
|
||||
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
|
||||
const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size,
|
||||
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
||||
const std::optional<torch::Tensor>& initial_state_idx,
|
||||
|
||||
@@ -556,7 +556,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor? cache_indices,"
|
||||
"Tensor? has_initial_state,"
|
||||
"Tensor! ssm_states,"
|
||||
"int pad_slot_id,"
|
||||
"int null_block_id,"
|
||||
"int block_size,"
|
||||
"Tensor? block_idx_first_scheduled_token,"
|
||||
"Tensor? block_idx_last_scheduled_token,"
|
||||
|
||||
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
@@ -122,7 +122,7 @@ def causal_conv1d_opcheck_fn(
|
||||
has_initial_state: torch.Tensor | None = None,
|
||||
conv_states: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
null_block_id: int = NULL_BLOCK_ID,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
@@ -158,15 +158,16 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
|
||||
batch = 2
|
||||
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
|
||||
x_ref = x.clone()
|
||||
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
|
||||
# +1 entry to reserve index 0 as null block
|
||||
conv_state = torch.randn(batch + 1, dim, width - 1, device=device, dtype=itype)
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state.detach().clone()
|
||||
# Start indices from 1, skipping null block at index 0
|
||||
conv_state_indices = torch.arange(1, batch + 1, dtype=torch.int32, device=device)
|
||||
conv_state_ref = conv_state[conv_state_indices].detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
|
||||
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
|
||||
|
||||
out = causal_conv1d_update(
|
||||
x,
|
||||
conv_state,
|
||||
@@ -179,7 +180,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
|
||||
x_ref, conv_state_ref, weight, bias, activation=activation
|
||||
)
|
||||
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.equal(conv_state[conv_state_indices], conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@@ -215,7 +216,8 @@ def test_causal_conv1d_update_with_batch_gather(
|
||||
|
||||
x_ref = x.clone()
|
||||
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
# +1 to exclude index 0 (null block)
|
||||
conv_state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
|
||||
dtype=torch.int32, device=device
|
||||
)
|
||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||
@@ -223,7 +225,9 @@ def test_causal_conv1d_update_with_batch_gather(
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
conv_state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
torch.as_tensor(
|
||||
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||
),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
@@ -248,7 +252,6 @@ def test_causal_conv1d_update_with_batch_gather(
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
out_ref = causal_conv1d_update_ref(
|
||||
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
|
||||
@@ -317,13 +320,19 @@ def test_causal_conv1d_varlen(
|
||||
has_initial_states = torch.randint(
|
||||
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
|
||||
)
|
||||
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
|
||||
:batch_size
|
||||
]
|
||||
# +1 to exclude index 0 (null block)
|
||||
state_indices = (
|
||||
torch.randperm(total_entries - 1, dtype=torch.int32, device=x.device)[
|
||||
:batch_size
|
||||
]
|
||||
+ 1
|
||||
)
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
torch.as_tensor(
|
||||
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
@@ -336,7 +345,6 @@ def test_causal_conv1d_varlen(
|
||||
cache_indices=padded_state_indices,
|
||||
has_initial_state=has_initial_states,
|
||||
activation=activation,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
|
||||
out_ref = []
|
||||
@@ -345,7 +353,7 @@ def test_causal_conv1d_varlen(
|
||||
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
|
||||
for i in range(len(seqlens[0])):
|
||||
x_s = [v[i].unsqueeze(0) for v in splits][0]
|
||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
||||
if padded_state_indices[i] == NULL_BLOCK_ID:
|
||||
continue
|
||||
out_ref_b.append(
|
||||
causal_conv1d_ref(
|
||||
|
||||
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
|
||||
|
||||
|
||||
def selective_state_update_ref(
|
||||
@@ -179,7 +179,7 @@ def selective_scan_opcheck_fn(
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
ssm_states=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
null_block_id=NULL_BLOCK_ID,
|
||||
block_size=2048,
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
@@ -229,7 +229,7 @@ def selective_scan_opcheck_fn(
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
ssm_states,
|
||||
pad_slot_id,
|
||||
null_block_id,
|
||||
block_size,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
@@ -351,7 +351,6 @@ def test_selective_scan(
|
||||
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
|
||||
if c > 0
|
||||
else None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
block_size=2048,
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
@@ -652,15 +651,21 @@ def test_selective_scan_varlen(
|
||||
prev_state_shape, device=u.device, dtype=itype, requires_grad=False
|
||||
)
|
||||
prev_state_ref = prev_state.clone()
|
||||
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[
|
||||
:batch_size
|
||||
]
|
||||
# +1 to exclude index 0 (null block)
|
||||
state_indices = (
|
||||
torch.randperm(total_entries - 1, dtype=torch.int32, device=u.device)[
|
||||
:batch_size
|
||||
]
|
||||
+ 1
|
||||
)
|
||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||
unused_states_bool[state_indices] = False
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
torch.as_tensor(
|
||||
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
@@ -690,7 +695,7 @@ def test_selective_scan_varlen(
|
||||
]
|
||||
for i in range(len(seqlens[0])):
|
||||
u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
|
||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
||||
if padded_state_indices[i] == NULL_BLOCK_ID:
|
||||
continue
|
||||
out_ref_s, _ = selective_scan_ref(
|
||||
u_s,
|
||||
@@ -758,7 +763,8 @@ def test_selective_state_update_with_batch_indices(
|
||||
padded_batch_size = batch_size + padding
|
||||
total_entries = 10 * batch_size
|
||||
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
|
||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
# +1 to exclude index 0 (null block)
|
||||
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
|
||||
dtype=torch.int32, device=device
|
||||
)
|
||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||
@@ -766,7 +772,9 @@ def test_selective_state_update_with_batch_indices(
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
torch.as_tensor(
|
||||
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||
),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
@@ -793,7 +801,6 @@ def test_selective_state_update_with_batch_indices(
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=out,
|
||||
)
|
||||
out_ref = selective_state_update_ref(
|
||||
@@ -849,7 +856,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
||||
state = torch.randn(
|
||||
total_entries, nheads, headdim, dstate, dtype=itype, device=device
|
||||
)
|
||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
# +1 to exclude index 0 (null block)
|
||||
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
|
||||
dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
@@ -887,7 +895,6 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=out,
|
||||
)
|
||||
out_ref = selective_state_update_ref(
|
||||
@@ -935,17 +942,18 @@ def test_selective_state_update_with_num_accepted_tokens(
|
||||
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
|
||||
|
||||
state_batch_indices = torch.full(
|
||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
||||
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||
)
|
||||
# Start from 1 to exclude null block at index 0
|
||||
initial_state_slots = torch.randint(
|
||||
0, 15, (batch_size,), device=device, dtype=torch.int32
|
||||
1, 15, (batch_size,), device=device, dtype=torch.int32
|
||||
)
|
||||
for seq_idx in range(batch_size):
|
||||
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
|
||||
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
|
||||
|
||||
dst_state_batch_indices = torch.full(
|
||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
||||
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||
)
|
||||
slot_offset = 15
|
||||
dst_slots_map = {}
|
||||
@@ -1013,7 +1021,6 @@ def test_selective_state_update_with_num_accepted_tokens(
|
||||
state_batch_indices=state_batch_indices,
|
||||
dst_state_batch_indices=dst_state_batch_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
@@ -1061,18 +1068,19 @@ def test_selective_state_update_varlen_with_num_accepted(
|
||||
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
|
||||
|
||||
state_batch_indices = torch.full(
|
||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
||||
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Start from 1 to exclude null block at index 0
|
||||
initial_state_slots = torch.randint(
|
||||
0, 15, (batch_size,), device=device, dtype=torch.int32
|
||||
1, 15, (batch_size,), device=device, dtype=torch.int32
|
||||
)
|
||||
for seq_idx in range(batch_size):
|
||||
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
|
||||
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
|
||||
|
||||
dst_state_batch_indices = torch.full(
|
||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
||||
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
slot_offset = 15
|
||||
@@ -1138,7 +1146,6 @@ def test_selective_state_update_varlen_with_num_accepted(
|
||||
state_batch_indices=state_batch_indices,
|
||||
dst_state_batch_indices=dst_state_batch_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
|
||||
for seq_idx in range(batch_size):
|
||||
|
||||
@@ -2062,7 +2062,7 @@ def selective_scan_fwd(
|
||||
cache_indices: torch.Tensor | None,
|
||||
has_initial_state: torch.Tensor | None,
|
||||
ssm_states: torch.Tensor,
|
||||
pad_slot_id: int,
|
||||
null_block_id: int,
|
||||
block_size: int = 1024,
|
||||
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||
@@ -2084,7 +2084,7 @@ def selective_scan_fwd(
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
ssm_states,
|
||||
pad_slot_id,
|
||||
null_block_id,
|
||||
block_size,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
from einops import rearrange
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -602,6 +603,7 @@ def _linear_attn_decode_kernel(
|
||||
cache_h_stride,
|
||||
cache_d0_stride,
|
||||
cache_d1_stride,
|
||||
pad_slot_id: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
@@ -616,8 +618,8 @@ def _linear_attn_decode_kernel(
|
||||
# Load slot index for the current batch
|
||||
slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
|
||||
|
||||
# Skip if slot_id is -1 (padding)
|
||||
if slot_id == -1:
|
||||
# Skip if slot_id is PAD_SLOT_ID (padding)
|
||||
if slot_id == pad_slot_id:
|
||||
return
|
||||
|
||||
batch_id = pid_b
|
||||
@@ -727,6 +729,7 @@ def linear_decode_forward_triton(
|
||||
cache_h_stride,
|
||||
cache_d0_stride,
|
||||
cache_d1_stride,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID, PAD_SLOT_ID
|
||||
|
||||
|
||||
@triton.jit()
|
||||
@@ -49,12 +49,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M
|
||||
# others
|
||||
pad_slot_id: tl.constexpr,
|
||||
null_block_id: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_BIAS: tl.constexpr,
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
IS_APC_ENABLED: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
HAS_NULL_BLOCK: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
@@ -133,9 +134,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index
|
||||
).to(tl.int64)
|
||||
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_states_input_coord == pad_slot_id:
|
||||
# not processing as this is not the actual sequence
|
||||
if HAS_NULL_BLOCK: # noqa
|
||||
if conv_states_input_coord == null_block_id:
|
||||
# not processing as this is a null block (padding)
|
||||
return
|
||||
conv_states_base = (
|
||||
conv_states_ptr
|
||||
@@ -475,6 +476,7 @@ def causal_conv1d_fn(
|
||||
has_initial_state: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
null_block_id: int = NULL_BLOCK_ID,
|
||||
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||
initial_state_idx: torch.Tensor | None = None,
|
||||
@@ -730,12 +732,13 @@ def causal_conv1d_fn(
|
||||
block_size_to_align // BLOCK_M,
|
||||
# others
|
||||
pad_slot_id,
|
||||
null_block_id,
|
||||
# META
|
||||
HAS_BIAS=bias is not None,
|
||||
KERNEL_WIDTH=width,
|
||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
HAS_NULL_BLOCK=null_block_id is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
# launch_cooperative_grid=True
|
||||
BLOCK_M=BLOCK_M,
|
||||
@@ -778,7 +781,7 @@ def _causal_conv1d_update_kernel(
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
# others
|
||||
pad_slot_id: tl.constexpr,
|
||||
null_block_id: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_BIAS: tl.constexpr,
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
@@ -787,7 +790,7 @@ def _causal_conv1d_update_kernel(
|
||||
IS_APC_ENABLED: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
HAS_NULL_BLOCK: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
# ruff: noqa: E501
|
||||
@@ -811,8 +814,8 @@ def _causal_conv1d_update_kernel(
|
||||
conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init
|
||||
).to(tl.int64)
|
||||
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_states_input_coord == pad_slot_id:
|
||||
if HAS_NULL_BLOCK: # noqa
|
||||
if conv_states_input_coord == null_block_id:
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
|
||||
@@ -1076,7 +1079,7 @@ def causal_conv1d_update(
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
query_start_loc: torch.Tensor | None = None,
|
||||
max_query_len: int = -1,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
null_block_id: int = NULL_BLOCK_ID,
|
||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||
initial_state_idx: torch.Tensor | None = None,
|
||||
validate_data=False,
|
||||
@@ -1111,16 +1114,16 @@ def causal_conv1d_update(
|
||||
max_query_len: int
|
||||
If query_start_loc is not None, this indicates the maximum query
|
||||
length in the batch.
|
||||
pad_slot_id: int
|
||||
if conv_state_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
null_block_id: int
|
||||
Block ID used to identify padded entries in
|
||||
conv_state_indices. Block 0 is the null block.
|
||||
for example: conv_state_indices = [null_block_id, 1, 20, null_block_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
|
||||
"""
|
||||
if validate_data:
|
||||
assert pad_slot_id is not None
|
||||
assert null_block_id is not None
|
||||
assert x.stride(1) == 1
|
||||
if isinstance(activation, bool):
|
||||
activation = "silu" if activation is True else None
|
||||
@@ -1225,7 +1228,7 @@ def causal_conv1d_update(
|
||||
stride_o_dim,
|
||||
stride_o_token,
|
||||
# others
|
||||
pad_slot_id,
|
||||
null_block_id,
|
||||
# META
|
||||
HAS_BIAS=bias is not None,
|
||||
KERNEL_WIDTH=width,
|
||||
@@ -1234,7 +1237,7 @@ def causal_conv1d_update(
|
||||
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
|
||||
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
HAS_NULL_BLOCK=null_block_id is not None,
|
||||
BLOCK_N=256,
|
||||
)
|
||||
if unsqueeze:
|
||||
|
||||
@@ -10,7 +10,7 @@ from packaging import version
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
|
||||
|
||||
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
|
||||
|
||||
@@ -75,7 +75,7 @@ def _selective_scan_update_kernel(
|
||||
out_ptr,
|
||||
state_batch_indices_ptr,
|
||||
dst_state_batch_indices_ptr,
|
||||
pad_slot_id,
|
||||
null_block_id,
|
||||
num_accepted_tokens_ptr,
|
||||
cu_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
@@ -203,7 +203,7 @@ def _selective_scan_update_kernel(
|
||||
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= state_batch_idx != pad_slot_id
|
||||
mask &= state_batch_idx != null_block_id
|
||||
state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
|
||||
|
||||
if HAS_DT_BIAS:
|
||||
@@ -257,7 +257,7 @@ def _selective_scan_update_kernel(
|
||||
if IS_SPEC_DECODING:
|
||||
dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T
|
||||
token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64)
|
||||
if token_dst_idx != pad_slot_id:
|
||||
if token_dst_idx != null_block_id:
|
||||
token_dst_ptrs = (
|
||||
state_ptr_base
|
||||
+ token_dst_idx * stride_state_batch
|
||||
@@ -329,7 +329,7 @@ def selective_state_update(
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
dst_state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
null_block_id=NULL_BLOCK_ID,
|
||||
out=None,
|
||||
num_accepted_tokens=None,
|
||||
cu_seqlens=None,
|
||||
@@ -348,12 +348,12 @@ def selective_state_update(
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
null_block_id: int
|
||||
if state_batch_indices is passed, lets the kernel identify
|
||||
padded entries that will not be processed,
|
||||
for example: state_batch_indices = [null_block_id, 1, 20,
|
||||
null_block_id] in this case, the kernel will not process
|
||||
entries at indices 0 and 3
|
||||
out: Preallocated ssm output tensor. Assume same shape as x.
|
||||
In-place updated.
|
||||
num_accepted_tokens: (batch,)
|
||||
@@ -488,7 +488,7 @@ def selective_state_update(
|
||||
out,
|
||||
state_batch_indices,
|
||||
dst_state_batch_indices,
|
||||
pad_slot_id,
|
||||
null_block_id,
|
||||
num_accepted_tokens,
|
||||
cu_seqlens,
|
||||
N,
|
||||
@@ -550,7 +550,7 @@ def selective_scan_fn(
|
||||
query_start_loc=None,
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
null_block_id=NULL_BLOCK_ID,
|
||||
block_size=1024,
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
@@ -588,10 +588,10 @@ def selective_scan_fn(
|
||||
indicate if the ssm_state at the corresponding index should be
|
||||
used as initial state. Not providing argument assumes
|
||||
there's no initial state
|
||||
pad_slot_id: int
|
||||
null_block_id: int
|
||||
if cache_indices is passed, lets the kernel identify padding entries
|
||||
that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
for example: cache_indices = [null_block_id, 1 ,20 ,null_block_id]
|
||||
in this case, the kernel will not process entries at indices 0 and 3
|
||||
block_size: int
|
||||
The block size to align the cached states to
|
||||
@@ -643,7 +643,7 @@ def selective_scan_fn(
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
ssm_states,
|
||||
pad_slot_id,
|
||||
null_block_id,
|
||||
block_size,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
|
||||
@@ -14,7 +14,7 @@ from vllm.v1.attention.backend import (
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
NULL_BLOCK_ID,
|
||||
compute_causal_conv1d_metadata,
|
||||
mamba_get_block_table_tensor,
|
||||
split_decodes_and_prefills,
|
||||
@@ -341,7 +341,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
spec_state_indices_tensor, non_blocking=True
|
||||
)
|
||||
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
|
||||
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
|
||||
spec_state_indices_tensor[num_spec_decodes:].fill_(NULL_BLOCK_ID)
|
||||
|
||||
self.spec_sequence_masks[:num_spec_decodes].copy_(
|
||||
spec_sequence_masks[:num_spec_decodes], non_blocking=True
|
||||
@@ -387,7 +387,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
|
||||
:batch_size
|
||||
]
|
||||
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
|
||||
non_spec_state_indices_tensor[num_decodes:].fill_(NULL_BLOCK_ID)
|
||||
|
||||
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
|
||||
non_spec_query_start_loc, non_blocking=True
|
||||
|
||||
@@ -15,7 +15,7 @@ from vllm.v1.attention.backend import (
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
NULL_BLOCK_ID,
|
||||
compute_causal_conv1d_metadata,
|
||||
mamba_get_block_table_tensor,
|
||||
split_decodes_and_prefills,
|
||||
@@ -504,7 +504,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
state_indices_tensor_d, non_blocking=True
|
||||
)
|
||||
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
|
||||
state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID
|
||||
state_indices_tensor_d[metadata.num_decodes :] = NULL_BLOCK_ID
|
||||
|
||||
if self.use_spec_decode and num_accepted_tokens is not None:
|
||||
assert query_start_loc_d is not None
|
||||
|
||||
@@ -366,12 +366,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
|
||||
|
||||
# Padded CUDA graph requests have block_table entries of -1.
|
||||
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
|
||||
# This is safe because padded requests have seq_lens=0, so the
|
||||
# kernel produces no meaningful output for those rows.
|
||||
block_table.clamp_(min=0)
|
||||
|
||||
max_decode_len = int(decode_lens_cpu.max().item())
|
||||
next_n = 1 + self.num_speculative_tokens
|
||||
use_native = not self.use_flattening and max_decode_len == next_n
|
||||
|
||||
@@ -42,6 +42,7 @@ KVCacheLayoutType = Literal["NHD", "HND"]
|
||||
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
NULL_BLOCK_ID = 0
|
||||
|
||||
|
||||
def is_valid_kv_cache_layout(value: str) -> bool:
|
||||
|
||||
@@ -122,6 +122,7 @@ from vllm.v1.attention.backend import (
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
NULL_BLOCK_ID,
|
||||
create_fast_prefill_custom_backend,
|
||||
get_dcp_local_seq_lens,
|
||||
reorder_batch_to_split_decodes_and_prefills,
|
||||
@@ -2135,9 +2136,9 @@ class GPUModelRunner(
|
||||
blk_table = self.input_batch.block_table[kv_cache_gid]
|
||||
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
|
||||
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
|
||||
# Fill unused block table entries with NULL_BLOCK_ID (null block)
|
||||
# for CUDAGraph padding. Block 0 is reserved for padding.
|
||||
blk_table_tensor[num_reqs:num_reqs_padded].fill_(NULL_BLOCK_ID)
|
||||
return blk_table_tensor
|
||||
|
||||
assert slot_mappings is not None
|
||||
|
||||
Reference in New Issue
Block a user