[Bugfix] Use null block (0) for padded block table entries (#35431)

Signed-off-by: SandishKumarHN <sandish@fb.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit was authored by SandishKumarHN on 2026-03-30 14:02:51 -07:00 and committed by GitHub.
parent 1fc69f59bb
commit bcc6f67447
15 changed files with 125 additions and 100 deletions

View File

@@ -12,7 +12,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_update,
)
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
def causal_conv1d_ref(
@@ -122,7 +122,7 @@ def causal_conv1d_opcheck_fn(
has_initial_state: torch.Tensor | None = None,
conv_states: torch.Tensor | None = None,
activation: str | None = "silu",
pad_slot_id: int = PAD_SLOT_ID,
null_block_id: int = NULL_BLOCK_ID,
):
"""
x: (batch, dim, seqlen)
@@ -158,15 +158,16 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
batch = 2
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
x_ref = x.clone()
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
# +1 entry to reserve index 0 as null block
conv_state = torch.randn(batch + 1, dim, width - 1, device=device, dtype=itype)
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone()
# Start indices from 1, skipping null block at index 0
conv_state_indices = torch.arange(1, batch + 1, dtype=torch.int32, device=device)
conv_state_ref = conv_state[conv_state_indices].detach().clone()
activation = None if not silu_activation else "silu"
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
out = causal_conv1d_update(
x,
conv_state,
@@ -179,7 +180,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
x_ref, conv_state_ref, weight, bias, activation=activation
)
assert torch.equal(conv_state, conv_state_ref)
assert torch.equal(conv_state[conv_state_indices], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@@ -215,7 +216,8 @@ def test_causal_conv1d_update_with_batch_gather(
x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
# +1 to exclude index 0 (null block)
conv_state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
dtype=torch.int32, device=device
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
@@ -223,7 +225,9 @@ def test_causal_conv1d_update_with_batch_gather(
padded_state_indices = torch.concat(
[
conv_state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
],
dim=0,
)
@@ -248,7 +252,6 @@ def test_causal_conv1d_update_with_batch_gather(
bias,
activation=activation,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = causal_conv1d_update_ref(
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
@@ -317,13 +320,19 @@ def test_causal_conv1d_varlen(
has_initial_states = torch.randint(
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
)
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
:batch_size
]
# +1 to exclude index 0 (null block)
state_indices = (
torch.randperm(total_entries - 1, dtype=torch.int32, device=x.device)[
:batch_size
]
+ 1
)
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
],
dim=-1,
)
@@ -336,7 +345,6 @@ def test_causal_conv1d_varlen(
cache_indices=padded_state_indices,
has_initial_state=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = []
@@ -345,7 +353,7 @@ def test_causal_conv1d_varlen(
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0]
if padded_state_indices[i] == PAD_SLOT_ID:
if padded_state_indices[i] == NULL_BLOCK_ID:
continue
out_ref_b.append(
causal_conv1d_ref(

View File

@@ -14,7 +14,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
def selective_state_update_ref(
@@ -179,7 +179,7 @@ def selective_scan_opcheck_fn(
cache_indices=None,
has_initial_state=None,
ssm_states=None,
pad_slot_id=PAD_SLOT_ID,
null_block_id=NULL_BLOCK_ID,
block_size=2048,
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
@@ -229,7 +229,7 @@ def selective_scan_opcheck_fn(
cache_indices,
has_initial_state,
ssm_states,
pad_slot_id,
null_block_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
@@ -351,7 +351,6 @@ def test_selective_scan(
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
if c > 0
else None,
pad_slot_id=PAD_SLOT_ID,
block_size=2048,
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
@@ -652,15 +651,21 @@ def test_selective_scan_varlen(
prev_state_shape, device=u.device, dtype=itype, requires_grad=False
)
prev_state_ref = prev_state.clone()
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[
:batch_size
]
# +1 to exclude index 0 (null block)
state_indices = (
torch.randperm(total_entries - 1, dtype=torch.int32, device=u.device)[
:batch_size
]
+ 1
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
],
dim=-1,
)
@@ -690,7 +695,7 @@ def test_selective_scan_varlen(
]
for i in range(len(seqlens[0])):
u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
if padded_state_indices[i] == PAD_SLOT_ID:
if padded_state_indices[i] == NULL_BLOCK_ID:
continue
out_ref_s, _ = selective_scan_ref(
u_s,
@@ -758,7 +763,8 @@ def test_selective_state_update_with_batch_indices(
padded_batch_size = batch_size + padding
total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(
# +1 to exclude index 0 (null block)
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
dtype=torch.int32, device=device
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
@@ -766,7 +772,9 @@ def test_selective_state_update_with_batch_indices(
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
],
dim=0,
)
@@ -793,7 +801,6 @@ def test_selective_state_update_with_batch_indices(
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out,
)
out_ref = selective_state_update_ref(
@@ -849,7 +856,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
state = torch.randn(
total_entries, nheads, headdim, dstate, dtype=itype, device=device
)
state_indices = torch.randperm(total_entries)[:batch_size].to(
# +1 to exclude index 0 (null block)
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
dtype=torch.int32, device=device
)
@@ -887,7 +895,6 @@ def test_selective_state_update_with_heads_with_batch_indices(
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out,
)
out_ref = selective_state_update_ref(
@@ -935,17 +942,18 @@ def test_selective_state_update_with_num_accepted_tokens(
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
)
# Start from 1 to exclude null block at index 0
initial_state_slots = torch.randint(
0, 15, (batch_size,), device=device, dtype=torch.int32
1, 15, (batch_size,), device=device, dtype=torch.int32
)
for seq_idx in range(batch_size):
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
dst_state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
)
slot_offset = 15
dst_slots_map = {}
@@ -1013,7 +1021,6 @@ def test_selective_state_update_with_num_accepted_tokens(
state_batch_indices=state_batch_indices,
dst_state_batch_indices=dst_state_batch_indices,
num_accepted_tokens=num_accepted_tokens,
pad_slot_id=PAD_SLOT_ID,
)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@@ -1061,18 +1068,19 @@ def test_selective_state_update_varlen_with_num_accepted(
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
)
# Start from 1 to exclude null block at index 0
initial_state_slots = torch.randint(
0, 15, (batch_size,), device=device, dtype=torch.int32
1, 15, (batch_size,), device=device, dtype=torch.int32
)
for seq_idx in range(batch_size):
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
dst_state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
)
slot_offset = 15
@@ -1138,7 +1146,6 @@ def test_selective_state_update_varlen_with_num_accepted(
state_batch_indices=state_batch_indices,
dst_state_batch_indices=dst_state_batch_indices,
num_accepted_tokens=num_accepted_tokens,
pad_slot_id=PAD_SLOT_ID,
)
for seq_idx in range(batch_size):