[Bugfix] Use null block (0) for padded block table entries (#35431)

Signed-off-by: SandishKumarHN <sandish@fb.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
SandishKumarHN
2026-03-30 14:02:51 -07:00
committed by GitHub
parent 1fc69f59bb
commit bcc6f67447
15 changed files with 125 additions and 100 deletions

View File

@@ -14,7 +14,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
def selective_state_update_ref(
@@ -179,7 +179,7 @@ def selective_scan_opcheck_fn(
cache_indices=None,
has_initial_state=None,
ssm_states=None,
pad_slot_id=PAD_SLOT_ID,
null_block_id=NULL_BLOCK_ID,
block_size=2048,
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
@@ -229,7 +229,7 @@ def selective_scan_opcheck_fn(
cache_indices,
has_initial_state,
ssm_states,
pad_slot_id,
null_block_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
@@ -351,7 +351,6 @@ def test_selective_scan(
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
if c > 0
else None,
pad_slot_id=PAD_SLOT_ID,
block_size=2048,
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
@@ -652,15 +651,21 @@ def test_selective_scan_varlen(
prev_state_shape, device=u.device, dtype=itype, requires_grad=False
)
prev_state_ref = prev_state.clone()
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[
:batch_size
]
# +1 to exclude index 0 (null block)
state_indices = (
torch.randperm(total_entries - 1, dtype=torch.int32, device=u.device)[
:batch_size
]
+ 1
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
],
dim=-1,
)
@@ -690,7 +695,7 @@ def test_selective_scan_varlen(
]
for i in range(len(seqlens[0])):
u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
if padded_state_indices[i] == PAD_SLOT_ID:
if padded_state_indices[i] == NULL_BLOCK_ID:
continue
out_ref_s, _ = selective_scan_ref(
u_s,
@@ -758,7 +763,8 @@ def test_selective_state_update_with_batch_indices(
padded_batch_size = batch_size + padding
total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(
# +1 to exclude index 0 (null block)
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
dtype=torch.int32, device=device
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
@@ -766,7 +772,9 @@ def test_selective_state_update_with_batch_indices(
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
],
dim=0,
)
@@ -793,7 +801,6 @@ def test_selective_state_update_with_batch_indices(
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out,
)
out_ref = selective_state_update_ref(
@@ -849,7 +856,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
state = torch.randn(
total_entries, nheads, headdim, dstate, dtype=itype, device=device
)
state_indices = torch.randperm(total_entries)[:batch_size].to(
# +1 to exclude index 0 (null block)
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
dtype=torch.int32, device=device
)
@@ -887,7 +895,6 @@ def test_selective_state_update_with_heads_with_batch_indices(
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out,
)
out_ref = selective_state_update_ref(
@@ -935,17 +942,18 @@ def test_selective_state_update_with_num_accepted_tokens(
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
)
# Start from 1 to exclude null block at index 0
initial_state_slots = torch.randint(
0, 15, (batch_size,), device=device, dtype=torch.int32
1, 15, (batch_size,), device=device, dtype=torch.int32
)
for seq_idx in range(batch_size):
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
dst_state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
)
slot_offset = 15
dst_slots_map = {}
@@ -1013,7 +1021,6 @@ def test_selective_state_update_with_num_accepted_tokens(
state_batch_indices=state_batch_indices,
dst_state_batch_indices=dst_state_batch_indices,
num_accepted_tokens=num_accepted_tokens,
pad_slot_id=PAD_SLOT_ID,
)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@@ -1061,18 +1068,19 @@ def test_selective_state_update_varlen_with_num_accepted(
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
)
# Start from 1 to exclude null block at index 0
initial_state_slots = torch.randint(
0, 15, (batch_size,), device=device, dtype=torch.int32
1, 15, (batch_size,), device=device, dtype=torch.int32
)
for seq_idx in range(batch_size):
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
dst_state_batch_indices = torch.full(
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
)
slot_offset = 15
@@ -1138,7 +1146,6 @@ def test_selective_state_update_varlen_with_num_accepted(
state_batch_indices=state_batch_indices,
dst_state_batch_indices=dst_state_batch_indices,
num_accepted_tokens=num_accepted_tokens,
pad_slot_id=PAD_SLOT_ID,
)
for seq_idx in range(batch_size):