[Bugfix] Use null block (0) for padded block table entries (#35431)

Signed-off-by: SandishKumarHN <sandish@fb.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
SandishKumarHN
2026-03-30 14:02:51 -07:00
committed by GitHub
parent 1fc69f59bb
commit bcc6f67447
15 changed files with 125 additions and 100 deletions

View File

@@ -21,7 +21,7 @@ struct SSMParamsBase {
int dim_ngroups_ratio; int dim_ngroups_ratio;
bool is_variable_B; bool is_variable_B;
bool is_variable_C; bool is_variable_C;
int64_t pad_slot_id; int64_t null_block_id;
bool delta_softplus; bool delta_softplus;
bool cache_enabled; bool cache_enabled;

View File

@@ -118,9 +118,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr); : reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; int cache_index;
// cache_index == params.pad_slot_id is defined as padding, so we exit early if (cache_indices == nullptr) {
if (cache_index == params.pad_slot_id){ cache_index = batch_id;
} else if (params.cache_enabled) {
const int* initial_state_idx = reinterpret_cast<const int*>(params.initial_state_idx_ptr);
cache_index = cache_indices[batch_id * params.cache_indices_stride + initial_state_idx[batch_id]];
} else {
cache_index = cache_indices[batch_id];
}
// Skip batch entries whose cache index maps to the null block (padding).
if (cache_indices != nullptr && cache_index == params.null_block_id){
return; return;
} }
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
@@ -527,7 +535,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const std::optional<at::Tensor>& cache_indices, const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state, const std::optional<at::Tensor>& has_initial_state,
bool varlen, bool varlen,
int64_t pad_slot_id, int64_t null_block_id,
int64_t block_size, int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token, const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token, const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
@@ -544,7 +552,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.dstate = dstate; params.dstate = dstate;
params.n_groups = n_groups; params.n_groups = n_groups;
params.dim_ngroups_ratio = dim / n_groups; params.dim_ngroups_ratio = dim / n_groups;
params.pad_slot_id = pad_slot_id; params.null_block_id = null_block_id;
params.delta_softplus = delta_softplus; params.delta_softplus = delta_softplus;
@@ -658,7 +666,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &ssm_states, const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided // used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early // in case of padding, the kernel will return early
int64_t pad_slot_id, int64_t null_block_id,
int64_t block_size, int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token, const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token, const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
@@ -805,7 +813,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
cache_indices, cache_indices,
has_initial_state, has_initial_state,
varlen, varlen,
pad_slot_id, null_block_id,
block_size, block_size,
block_idx_first_scheduled_token, block_idx_first_scheduled_token,
block_idx_last_scheduled_token, block_idx_last_scheduled_token,

View File

@@ -298,7 +298,7 @@ void selective_scan_fwd(
const std::optional<torch::Tensor>& query_start_loc, const std::optional<torch::Tensor>& query_start_loc,
const std::optional<torch::Tensor>& cache_indices, const std::optional<torch::Tensor>& cache_indices,
const std::optional<torch::Tensor>& has_initial_state, const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size, const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size,
const std::optional<torch::Tensor>& block_idx_first_scheduled_token, const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
const std::optional<torch::Tensor>& block_idx_last_scheduled_token, const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
const std::optional<torch::Tensor>& initial_state_idx, const std::optional<torch::Tensor>& initial_state_idx,

View File

@@ -556,7 +556,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? cache_indices," "Tensor? cache_indices,"
"Tensor? has_initial_state," "Tensor? has_initial_state,"
"Tensor! ssm_states," "Tensor! ssm_states,"
"int pad_slot_id," "int null_block_id,"
"int block_size," "int block_size,"
"Tensor? block_idx_first_scheduled_token," "Tensor? block_idx_first_scheduled_token,"
"Tensor? block_idx_last_scheduled_token," "Tensor? block_idx_last_scheduled_token,"

View File

@@ -12,7 +12,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_update, causal_conv1d_update,
) )
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
def causal_conv1d_ref( def causal_conv1d_ref(
@@ -122,7 +122,7 @@ def causal_conv1d_opcheck_fn(
has_initial_state: torch.Tensor | None = None, has_initial_state: torch.Tensor | None = None,
conv_states: torch.Tensor | None = None, conv_states: torch.Tensor | None = None,
activation: str | None = "silu", activation: str | None = "silu",
pad_slot_id: int = PAD_SLOT_ID, null_block_id: int = NULL_BLOCK_ID,
): ):
""" """
x: (batch, dim, seqlen) x: (batch, dim, seqlen)
@@ -158,15 +158,16 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
batch = 2 batch = 2
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
x_ref = x.clone() x_ref = x.clone()
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype) # +1 entry to reserve index 0 as null block
conv_state = torch.randn(batch + 1, dim, width - 1, device=device, dtype=itype)
weight = torch.randn(dim, width, device=device, dtype=itype) weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone() # Start indices from 1, skipping null block at index 0
conv_state_indices = torch.arange(1, batch + 1, dtype=torch.int32, device=device)
conv_state_ref = conv_state[conv_state_indices].detach().clone()
activation = None if not silu_activation else "silu" activation = None if not silu_activation else "silu"
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
out = causal_conv1d_update( out = causal_conv1d_update(
x, x,
conv_state, conv_state,
@@ -179,7 +180,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
x_ref, conv_state_ref, weight, bias, activation=activation x_ref, conv_state_ref, weight, bias, activation=activation
) )
assert torch.equal(conv_state, conv_state_ref) assert torch.equal(conv_state[conv_state_indices], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@@ -215,7 +216,8 @@ def test_causal_conv1d_update_with_batch_gather(
x_ref = x.clone() x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to( # +1 to exclude index 0 (null block)
conv_state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
dtype=torch.int32, device=device dtype=torch.int32, device=device
) )
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
@@ -223,7 +225,9 @@ def test_causal_conv1d_update_with_batch_gather(
padded_state_indices = torch.concat( padded_state_indices = torch.concat(
[ [
conv_state_indices, conv_state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
], ],
dim=0, dim=0,
) )
@@ -248,7 +252,6 @@ def test_causal_conv1d_update_with_batch_gather(
bias, bias,
activation=activation, activation=activation,
conv_state_indices=padded_state_indices, conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
) )
out_ref = causal_conv1d_update_ref( out_ref = causal_conv1d_update_ref(
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
@@ -317,13 +320,19 @@ def test_causal_conv1d_varlen(
has_initial_states = torch.randint( has_initial_states = torch.randint(
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device 0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
) )
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[ # +1 to exclude index 0 (null block)
:batch_size state_indices = (
] torch.randperm(total_entries - 1, dtype=torch.int32, device=x.device)[
:batch_size
]
+ 1
)
padded_state_indices = torch.concat( padded_state_indices = torch.concat(
[ [
state_indices, state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
], ],
dim=-1, dim=-1,
) )
@@ -336,7 +345,6 @@ def test_causal_conv1d_varlen(
cache_indices=padded_state_indices, cache_indices=padded_state_indices,
has_initial_state=has_initial_states, has_initial_state=has_initial_states,
activation=activation, activation=activation,
pad_slot_id=PAD_SLOT_ID,
) )
out_ref = [] out_ref = []
@@ -345,7 +353,7 @@ def test_causal_conv1d_varlen(
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)] splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])): for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0] x_s = [v[i].unsqueeze(0) for v in splits][0]
if padded_state_indices[i] == PAD_SLOT_ID: if padded_state_indices[i] == NULL_BLOCK_ID:
continue continue
out_ref_b.append( out_ref_b.append(
causal_conv1d_ref( causal_conv1d_ref(

View File

@@ -14,7 +14,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
def selective_state_update_ref( def selective_state_update_ref(
@@ -179,7 +179,7 @@ def selective_scan_opcheck_fn(
cache_indices=None, cache_indices=None,
has_initial_state=None, has_initial_state=None,
ssm_states=None, ssm_states=None,
pad_slot_id=PAD_SLOT_ID, null_block_id=NULL_BLOCK_ID,
block_size=2048, block_size=2048,
block_idx_first_scheduled_token=None, block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None, block_idx_last_scheduled_token=None,
@@ -229,7 +229,7 @@ def selective_scan_opcheck_fn(
cache_indices, cache_indices,
has_initial_state, has_initial_state,
ssm_states, ssm_states,
pad_slot_id, null_block_id,
block_size, block_size,
block_idx_first_scheduled_token, block_idx_first_scheduled_token,
block_idx_last_scheduled_token, block_idx_last_scheduled_token,
@@ -351,7 +351,6 @@ def test_selective_scan(
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool) has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
if c > 0 if c > 0
else None, else None,
pad_slot_id=PAD_SLOT_ID,
block_size=2048, block_size=2048,
block_idx_first_scheduled_token=None, block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None, block_idx_last_scheduled_token=None,
@@ -652,15 +651,21 @@ def test_selective_scan_varlen(
prev_state_shape, device=u.device, dtype=itype, requires_grad=False prev_state_shape, device=u.device, dtype=itype, requires_grad=False
) )
prev_state_ref = prev_state.clone() prev_state_ref = prev_state.clone()
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[ # +1 to exclude index 0 (null block)
:batch_size state_indices = (
] torch.randperm(total_entries - 1, dtype=torch.int32, device=u.device)[
:batch_size
]
+ 1
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
unused_states_bool[state_indices] = False unused_states_bool[state_indices] = False
padded_state_indices = torch.concat( padded_state_indices = torch.concat(
[ [
state_indices, state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
], ],
dim=-1, dim=-1,
) )
@@ -690,7 +695,7 @@ def test_selective_scan_varlen(
] ]
for i in range(len(seqlens[0])): for i in range(len(seqlens[0])):
u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits) u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
if padded_state_indices[i] == PAD_SLOT_ID: if padded_state_indices[i] == NULL_BLOCK_ID:
continue continue
out_ref_s, _ = selective_scan_ref( out_ref_s, _ = selective_scan_ref(
u_s, u_s,
@@ -758,7 +763,8 @@ def test_selective_state_update_with_batch_indices(
padded_batch_size = batch_size + padding padded_batch_size = batch_size + padding
total_entries = 10 * batch_size total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to( # +1 to exclude index 0 (null block)
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
dtype=torch.int32, device=device dtype=torch.int32, device=device
) )
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
@@ -766,7 +772,9 @@ def test_selective_state_update_with_batch_indices(
padded_state_indices = torch.concat( padded_state_indices = torch.concat(
[ [
state_indices, state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
], ],
dim=0, dim=0,
) )
@@ -793,7 +801,6 @@ def test_selective_state_update_with_batch_indices(
dt_bias=dt_bias, dt_bias=dt_bias,
dt_softplus=True, dt_softplus=True,
state_batch_indices=padded_state_indices, state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out, out=out,
) )
out_ref = selective_state_update_ref( out_ref = selective_state_update_ref(
@@ -849,7 +856,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
state = torch.randn( state = torch.randn(
total_entries, nheads, headdim, dstate, dtype=itype, device=device total_entries, nheads, headdim, dstate, dtype=itype, device=device
) )
state_indices = torch.randperm(total_entries)[:batch_size].to( # +1 to exclude index 0 (null block)
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
dtype=torch.int32, device=device dtype=torch.int32, device=device
) )
@@ -887,7 +895,6 @@ def test_selective_state_update_with_heads_with_batch_indices(
dt_bias=dt_bias, dt_bias=dt_bias,
dt_softplus=True, dt_softplus=True,
state_batch_indices=state_indices, state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out, out=out,
) )
out_ref = selective_state_update_ref( out_ref = selective_state_update_ref(
@@ -935,17 +942,18 @@ def test_selective_state_update_with_num_accepted_tokens(
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device) state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
state_batch_indices = torch.full( state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device (batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
) )
# Start from 1 to exclude null block at index 0
initial_state_slots = torch.randint( initial_state_slots = torch.randint(
0, 15, (batch_size,), device=device, dtype=torch.int32 1, 15, (batch_size,), device=device, dtype=torch.int32
) )
for seq_idx in range(batch_size): for seq_idx in range(batch_size):
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0) token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx] state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
dst_state_batch_indices = torch.full( dst_state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device (batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
) )
slot_offset = 15 slot_offset = 15
dst_slots_map = {} dst_slots_map = {}
@@ -1013,7 +1021,6 @@ def test_selective_state_update_with_num_accepted_tokens(
state_batch_indices=state_batch_indices, state_batch_indices=state_batch_indices,
dst_state_batch_indices=dst_state_batch_indices, dst_state_batch_indices=dst_state_batch_indices,
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
pad_slot_id=PAD_SLOT_ID,
) )
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@@ -1061,18 +1068,19 @@ def test_selective_state_update_varlen_with_num_accepted(
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device) state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
state_batch_indices = torch.full( state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device (batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
) )
# Start from 1 to exclude null block at index 0
initial_state_slots = torch.randint( initial_state_slots = torch.randint(
0, 15, (batch_size,), device=device, dtype=torch.int32 1, 15, (batch_size,), device=device, dtype=torch.int32
) )
for seq_idx in range(batch_size): for seq_idx in range(batch_size):
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0) token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx] state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
dst_state_batch_indices = torch.full( dst_state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device (batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
) )
slot_offset = 15 slot_offset = 15
@@ -1138,7 +1146,6 @@ def test_selective_state_update_varlen_with_num_accepted(
state_batch_indices=state_batch_indices, state_batch_indices=state_batch_indices,
dst_state_batch_indices=dst_state_batch_indices, dst_state_batch_indices=dst_state_batch_indices,
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
pad_slot_id=PAD_SLOT_ID,
) )
for seq_idx in range(batch_size): for seq_idx in range(batch_size):

View File

@@ -2062,7 +2062,7 @@ def selective_scan_fwd(
cache_indices: torch.Tensor | None, cache_indices: torch.Tensor | None,
has_initial_state: torch.Tensor | None, has_initial_state: torch.Tensor | None,
ssm_states: torch.Tensor, ssm_states: torch.Tensor,
pad_slot_id: int, null_block_id: int,
block_size: int = 1024, block_size: int = 1024,
block_idx_first_scheduled_token: torch.Tensor | None = None, block_idx_first_scheduled_token: torch.Tensor | None = None,
block_idx_last_scheduled_token: torch.Tensor | None = None, block_idx_last_scheduled_token: torch.Tensor | None = None,
@@ -2084,7 +2084,7 @@ def selective_scan_fwd(
cache_indices, cache_indices,
has_initial_state, has_initial_state,
ssm_states, ssm_states,
pad_slot_id, null_block_id,
block_size, block_size,
block_idx_first_scheduled_token, block_idx_first_scheduled_token,
block_idx_last_scheduled_token, block_idx_last_scheduled_token,

View File

@@ -5,6 +5,7 @@ import torch
from einops import rearrange from einops import rearrange
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@triton.jit @triton.jit
@@ -602,6 +603,7 @@ def _linear_attn_decode_kernel(
cache_h_stride, cache_h_stride,
cache_d0_stride, cache_d0_stride,
cache_d1_stride, cache_d1_stride,
pad_slot_id: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
""" """
@@ -616,8 +618,8 @@ def _linear_attn_decode_kernel(
# Load slot index for the current batch # Load slot index for the current batch
slot_id = tl.load(slot_idx + pid_b).to(tl.int64) slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
# Skip if slot_id is -1 (padding) # Skip if slot_id is PAD_SLOT_ID (padding)
if slot_id == -1: if slot_id == pad_slot_id:
return return
batch_id = pid_b batch_id = pid_b
@@ -727,6 +729,7 @@ def linear_decode_forward_triton(
cache_h_stride, cache_h_stride,
cache_d0_stride, cache_d0_stride,
cache_d1_stride, cache_d1_stride,
pad_slot_id=PAD_SLOT_ID,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
) )

View File

@@ -9,7 +9,7 @@ import numpy as np
import torch import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm.v1.attention.backends.utils import NULL_BLOCK_ID, PAD_SLOT_ID
@triton.jit() @triton.jit()
@@ -49,12 +49,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching
stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M
# others # others
pad_slot_id: tl.constexpr, pad_slot_id: tl.constexpr,
null_block_id: tl.constexpr,
# Meta-parameters # Meta-parameters
HAS_BIAS: tl.constexpr, HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr, KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr, SILU_ACTIVATION: tl.constexpr,
IS_APC_ENABLED: tl.constexpr, IS_APC_ENABLED: tl.constexpr,
USE_PAD_SLOT: tl.constexpr, HAS_NULL_BLOCK: tl.constexpr,
NP2_STATELEN: tl.constexpr, NP2_STATELEN: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
@@ -133,9 +134,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index
).to(tl.int64) ).to(tl.int64)
if USE_PAD_SLOT: # noqa if HAS_NULL_BLOCK: # noqa
if conv_states_input_coord == pad_slot_id: if conv_states_input_coord == null_block_id:
# not processing as this is not the actual sequence # not processing as this is a null block (padding)
return return
conv_states_base = ( conv_states_base = (
conv_states_ptr conv_states_ptr
@@ -475,6 +476,7 @@ def causal_conv1d_fn(
has_initial_state: torch.Tensor | None = None, has_initial_state: torch.Tensor | None = None,
activation: str | None = "silu", activation: str | None = "silu",
pad_slot_id: int = PAD_SLOT_ID, pad_slot_id: int = PAD_SLOT_ID,
null_block_id: int = NULL_BLOCK_ID,
block_idx_first_scheduled_token: torch.Tensor | None = None, block_idx_first_scheduled_token: torch.Tensor | None = None,
block_idx_last_scheduled_token: torch.Tensor | None = None, block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None, initial_state_idx: torch.Tensor | None = None,
@@ -730,12 +732,13 @@ def causal_conv1d_fn(
block_size_to_align // BLOCK_M, block_size_to_align // BLOCK_M,
# others # others
pad_slot_id, pad_slot_id,
null_block_id,
# META # META
HAS_BIAS=bias is not None, HAS_BIAS=bias is not None,
KERNEL_WIDTH=width, KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"], SILU_ACTIVATION=activation in ["silu", "swish"],
IS_APC_ENABLED=block_idx_last_scheduled_token is not None, IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
USE_PAD_SLOT=pad_slot_id is not None, HAS_NULL_BLOCK=null_block_id is not None,
NP2_STATELEN=np2_statelen, NP2_STATELEN=np2_statelen,
# launch_cooperative_grid=True # launch_cooperative_grid=True
BLOCK_M=BLOCK_M, BLOCK_M=BLOCK_M,
@@ -778,7 +781,7 @@ def _causal_conv1d_update_kernel(
stride_o_dim: tl.constexpr, stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr, stride_o_token: tl.constexpr,
# others # others
pad_slot_id: tl.constexpr, null_block_id: tl.constexpr,
# Meta-parameters # Meta-parameters
HAS_BIAS: tl.constexpr, HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr, KERNEL_WIDTH: tl.constexpr,
@@ -787,7 +790,7 @@ def _causal_conv1d_update_kernel(
IS_APC_ENABLED: tl.constexpr, IS_APC_ENABLED: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr, IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr, NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr, HAS_NULL_BLOCK: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
): ):
# ruff: noqa: E501 # ruff: noqa: E501
@@ -811,8 +814,8 @@ def _causal_conv1d_update_kernel(
conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init
).to(tl.int64) ).to(tl.int64)
if USE_PAD_SLOT: # noqa if HAS_NULL_BLOCK: # noqa
if conv_states_input_coord == pad_slot_id: if conv_states_input_coord == null_block_id:
# not processing as this is not the actual sequence # not processing as this is not the actual sequence
return return
@@ -1076,7 +1079,7 @@ def causal_conv1d_update(
num_accepted_tokens: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None,
query_start_loc: torch.Tensor | None = None, query_start_loc: torch.Tensor | None = None,
max_query_len: int = -1, max_query_len: int = -1,
pad_slot_id: int = PAD_SLOT_ID, null_block_id: int = NULL_BLOCK_ID,
block_idx_last_scheduled_token: torch.Tensor | None = None, block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None, initial_state_idx: torch.Tensor | None = None,
validate_data=False, validate_data=False,
@@ -1111,16 +1114,16 @@ def causal_conv1d_update(
max_query_len: int max_query_len: int
If query_start_loc is not None, this indicates the maximum query If query_start_loc is not None, this indicates the maximum query
length in the batch. length in the batch.
pad_slot_id: int null_block_id: int
if conv_state_indices is passed, lets the kernel identify padded Block ID used to identify padded entries in
entries that will not be processed, conv_state_indices. Block 0 is the null block.
for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] for example: conv_state_indices = [null_block_id, 1, 20, null_block_id]
in this case, the kernel will not process entries at in this case, the kernel will not process entries at
indices 0 and 3 indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
""" """
if validate_data: if validate_data:
assert pad_slot_id is not None assert null_block_id is not None
assert x.stride(1) == 1 assert x.stride(1) == 1
if isinstance(activation, bool): if isinstance(activation, bool):
activation = "silu" if activation is True else None activation = "silu" if activation is True else None
@@ -1225,7 +1228,7 @@ def causal_conv1d_update(
stride_o_dim, stride_o_dim,
stride_o_token, stride_o_token,
# others # others
pad_slot_id, null_block_id,
# META # META
HAS_BIAS=bias is not None, HAS_BIAS=bias is not None,
KERNEL_WIDTH=width, KERNEL_WIDTH=width,
@@ -1234,7 +1237,7 @@ def causal_conv1d_update(
IS_APC_ENABLED=block_idx_last_scheduled_token is not None, IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None, IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen, NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None, HAS_NULL_BLOCK=null_block_id is not None,
BLOCK_N=256, BLOCK_N=256,
) )
if unsqueeze: if unsqueeze:

View File

@@ -10,7 +10,7 @@ from packaging import version
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0")) TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
@@ -75,7 +75,7 @@ def _selective_scan_update_kernel(
out_ptr, out_ptr,
state_batch_indices_ptr, state_batch_indices_ptr,
dst_state_batch_indices_ptr, dst_state_batch_indices_ptr,
pad_slot_id, null_block_id,
num_accepted_tokens_ptr, num_accepted_tokens_ptr,
cu_seqlens_ptr, cu_seqlens_ptr,
# Matrix dimensions # Matrix dimensions
@@ -203,7 +203,7 @@ def _selective_scan_update_kernel(
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES: if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id mask &= state_batch_idx != null_block_id
state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32) state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
if HAS_DT_BIAS: if HAS_DT_BIAS:
@@ -257,7 +257,7 @@ def _selective_scan_update_kernel(
if IS_SPEC_DECODING: if IS_SPEC_DECODING:
dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T
token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64) token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64)
if token_dst_idx != pad_slot_id: if token_dst_idx != null_block_id:
token_dst_ptrs = ( token_dst_ptrs = (
state_ptr_base state_ptr_base
+ token_dst_idx * stride_state_batch + token_dst_idx * stride_state_batch
@@ -329,7 +329,7 @@ def selective_state_update(
dt_softplus=False, dt_softplus=False,
state_batch_indices=None, state_batch_indices=None,
dst_state_batch_indices=None, dst_state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID, null_block_id=NULL_BLOCK_ID,
out=None, out=None,
num_accepted_tokens=None, num_accepted_tokens=None,
cu_seqlens=None, cu_seqlens=None,
@@ -348,12 +348,12 @@ def selective_state_update(
D: (dim,) or (nheads, dim) D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim) z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim) dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int null_block_id: int
if cache_indices is passed, lets the kernel identify padded if state_batch_indices is passed, lets the kernel identify
entries that will not be processed, padded entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] for example: state_batch_indices = [null_block_id, 1, 20,
in this case, the kernel will not process entries at null_block_id] in this case, the kernel will not process
indices 0 and 3 entries at indices 0 and 3
out: Preallocated ssm output tensor. Assume same shape as x. out: Preallocated ssm output tensor. Assume same shape as x.
In-place updated. In-place updated.
num_accepted_tokens: (batch,) num_accepted_tokens: (batch,)
@@ -488,7 +488,7 @@ def selective_state_update(
out, out,
state_batch_indices, state_batch_indices,
dst_state_batch_indices, dst_state_batch_indices,
pad_slot_id, null_block_id,
num_accepted_tokens, num_accepted_tokens,
cu_seqlens, cu_seqlens,
N, N,
@@ -550,7 +550,7 @@ def selective_scan_fn(
query_start_loc=None, query_start_loc=None,
cache_indices=None, cache_indices=None,
has_initial_state=None, has_initial_state=None,
pad_slot_id=PAD_SLOT_ID, null_block_id=NULL_BLOCK_ID,
block_size=1024, block_size=1024,
block_idx_first_scheduled_token=None, block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None, block_idx_last_scheduled_token=None,
@@ -588,10 +588,10 @@ def selective_scan_fn(
indicate if the ssm_state at the corresponding index should be indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes used as initial state. Not providing argument assumes
there's no initial state there's no initial state
pad_slot_id: int null_block_id: int
if cache_indices is passed, lets the kernel identify padding entries if cache_indices is passed, lets the kernel identify padding entries
that will not be processed, that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] for example: cache_indices = [null_block_id, 1 ,20 ,null_block_id]
in this case, the kernel will not process entries at indices 0 and 3 in this case, the kernel will not process entries at indices 0 and 3
block_size: int block_size: int
The block size to align the cached states to The block size to align the cached states to
@@ -643,7 +643,7 @@ def selective_scan_fn(
cache_indices, cache_indices,
has_initial_state, has_initial_state,
ssm_states, ssm_states,
pad_slot_id, null_block_id,
block_size, block_size,
block_idx_first_scheduled_token, block_idx_first_scheduled_token,
block_idx_last_scheduled_token, block_idx_last_scheduled_token,

View File

@@ -14,7 +14,7 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata, CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID, NULL_BLOCK_ID,
compute_causal_conv1d_metadata, compute_causal_conv1d_metadata,
mamba_get_block_table_tensor, mamba_get_block_table_tensor,
split_decodes_and_prefills, split_decodes_and_prefills,
@@ -341,7 +341,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_state_indices_tensor, non_blocking=True spec_state_indices_tensor, non_blocking=True
) )
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size] spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) spec_state_indices_tensor[num_spec_decodes:].fill_(NULL_BLOCK_ID)
self.spec_sequence_masks[:num_spec_decodes].copy_( self.spec_sequence_masks[:num_spec_decodes].copy_(
spec_sequence_masks[:num_spec_decodes], non_blocking=True spec_sequence_masks[:num_spec_decodes], non_blocking=True
@@ -387,7 +387,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[ non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
:batch_size :batch_size
] ]
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID) non_spec_state_indices_tensor[num_decodes:].fill_(NULL_BLOCK_ID)
self.non_spec_query_start_loc[: num_decodes + 1].copy_( self.non_spec_query_start_loc[: num_decodes + 1].copy_(
non_spec_query_start_loc, non_blocking=True non_spec_query_start_loc, non_blocking=True

View File

@@ -15,7 +15,7 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata, CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID, NULL_BLOCK_ID,
compute_causal_conv1d_metadata, compute_causal_conv1d_metadata,
mamba_get_block_table_tensor, mamba_get_block_table_tensor,
split_decodes_and_prefills, split_decodes_and_prefills,
@@ -504,7 +504,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
state_indices_tensor_d, non_blocking=True state_indices_tensor_d, non_blocking=True
) )
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs] state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID state_indices_tensor_d[metadata.num_decodes :] = NULL_BLOCK_ID
if self.use_spec_decode and num_accepted_tokens is not None: if self.use_spec_decode and num_accepted_tokens is not None:
assert query_start_loc_d is not None assert query_start_loc_d is not None

View File

@@ -366,12 +366,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
seq_lens = common_attn_metadata.seq_lens[:num_decodes] seq_lens = common_attn_metadata.seq_lens[:num_decodes]
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...] block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
# Padded CUDA graph requests have block_table entries of -1.
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
# This is safe because padded requests have seq_lens=0, so the
# kernel produces no meaningful output for those rows.
block_table.clamp_(min=0)
max_decode_len = int(decode_lens_cpu.max().item()) max_decode_len = int(decode_lens_cpu.max().item())
next_n = 1 + self.num_speculative_tokens next_n = 1 + self.num_speculative_tokens
use_native = not self.use_flattening and max_decode_len == next_n use_native = not self.use_flattening and max_decode_len == next_n

View File

@@ -42,6 +42,7 @@ KVCacheLayoutType = Literal["NHD", "HND"]
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None _KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
PAD_SLOT_ID = -1 PAD_SLOT_ID = -1
NULL_BLOCK_ID = 0
def is_valid_kv_cache_layout(value: str) -> bool: def is_valid_kv_cache_layout(value: str) -> bool:

View File

@@ -122,6 +122,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
NULL_BLOCK_ID,
create_fast_prefill_custom_backend, create_fast_prefill_custom_backend,
get_dcp_local_seq_lens, get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills, reorder_batch_to_split_decodes_and_prefills,
@@ -2135,9 +2136,9 @@ class GPUModelRunner(
blk_table = self.input_batch.block_table[kv_cache_gid] blk_table = self.input_batch.block_table[kv_cache_gid]
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded) blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
# Fill unused with -1. Needed for reshape_and_cache in full cuda # Fill unused block table entries with NULL_BLOCK_ID (null block)
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID # for CUDAGraph padding. Block 0 is reserved for padding.
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) blk_table_tensor[num_reqs:num_reqs_padded].fill_(NULL_BLOCK_ID)
return blk_table_tensor return blk_table_tensor
assert slot_mappings is not None assert slot_mappings is not None