[Bugfix] Use null block (0) for padded block table entries (#35431)

Signed-off-by: SandishKumarHN <sandish@fb.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
SandishKumarHN
2026-03-30 14:02:51 -07:00
committed by GitHub
parent 1fc69f59bb
commit bcc6f67447
15 changed files with 125 additions and 100 deletions

View File

@@ -12,7 +12,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_update,
)
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
def causal_conv1d_ref(
@@ -122,7 +122,7 @@ def causal_conv1d_opcheck_fn(
has_initial_state: torch.Tensor | None = None,
conv_states: torch.Tensor | None = None,
activation: str | None = "silu",
pad_slot_id: int = PAD_SLOT_ID,
null_block_id: int = NULL_BLOCK_ID,
):
"""
x: (batch, dim, seqlen)
@@ -158,15 +158,16 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
batch = 2
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
x_ref = x.clone()
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
# +1 entry to reserve index 0 as null block
conv_state = torch.randn(batch + 1, dim, width - 1, device=device, dtype=itype)
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone()
# Start indices from 1, skipping null block at index 0
conv_state_indices = torch.arange(1, batch + 1, dtype=torch.int32, device=device)
conv_state_ref = conv_state[conv_state_indices].detach().clone()
activation = None if not silu_activation else "silu"
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
out = causal_conv1d_update(
x,
conv_state,
@@ -179,7 +180,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
x_ref, conv_state_ref, weight, bias, activation=activation
)
assert torch.equal(conv_state, conv_state_ref)
assert torch.equal(conv_state[conv_state_indices], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@@ -215,7 +216,8 @@ def test_causal_conv1d_update_with_batch_gather(
x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
# +1 to exclude index 0 (null block)
conv_state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
dtype=torch.int32, device=device
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
@@ -223,7 +225,9 @@ def test_causal_conv1d_update_with_batch_gather(
padded_state_indices = torch.concat(
[
conv_state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
],
dim=0,
)
@@ -248,7 +252,6 @@ def test_causal_conv1d_update_with_batch_gather(
bias,
activation=activation,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = causal_conv1d_update_ref(
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
@@ -317,13 +320,19 @@ def test_causal_conv1d_varlen(
has_initial_states = torch.randint(
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
)
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
:batch_size
]
# +1 to exclude index 0 (null block)
state_indices = (
torch.randperm(total_entries - 1, dtype=torch.int32, device=x.device)[
:batch_size
]
+ 1
)
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
torch.as_tensor(
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
),
],
dim=-1,
)
@@ -336,7 +345,6 @@ def test_causal_conv1d_varlen(
cache_indices=padded_state_indices,
has_initial_state=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = []
@@ -345,7 +353,7 @@ def test_causal_conv1d_varlen(
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0]
if padded_state_indices[i] == PAD_SLOT_ID:
if padded_state_indices[i] == NULL_BLOCK_ID:
continue
out_ref_b.append(
causal_conv1d_ref(