[Bugfix] Use null block (0) for padded block table entries (#35431)
Signed-off-by: SandishKumarHN <sandish@fb.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
@@ -122,7 +122,7 @@ def causal_conv1d_opcheck_fn(
|
||||
has_initial_state: torch.Tensor | None = None,
|
||||
conv_states: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
null_block_id: int = NULL_BLOCK_ID,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
@@ -158,15 +158,16 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
|
||||
batch = 2
|
||||
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
|
||||
x_ref = x.clone()
|
||||
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
|
||||
# +1 entry to reserve index 0 as null block
|
||||
conv_state = torch.randn(batch + 1, dim, width - 1, device=device, dtype=itype)
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state.detach().clone()
|
||||
# Start indices from 1, skipping null block at index 0
|
||||
conv_state_indices = torch.arange(1, batch + 1, dtype=torch.int32, device=device)
|
||||
conv_state_ref = conv_state[conv_state_indices].detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
|
||||
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
|
||||
|
||||
out = causal_conv1d_update(
|
||||
x,
|
||||
conv_state,
|
||||
@@ -179,7 +180,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
|
||||
x_ref, conv_state_ref, weight, bias, activation=activation
|
||||
)
|
||||
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.equal(conv_state[conv_state_indices], conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@@ -215,7 +216,8 @@ def test_causal_conv1d_update_with_batch_gather(
|
||||
|
||||
x_ref = x.clone()
|
||||
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
# +1 to exclude index 0 (null block)
|
||||
conv_state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
|
||||
dtype=torch.int32, device=device
|
||||
)
|
||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||
@@ -223,7 +225,9 @@ def test_causal_conv1d_update_with_batch_gather(
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
conv_state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
torch.as_tensor(
|
||||
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||
),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
@@ -248,7 +252,6 @@ def test_causal_conv1d_update_with_batch_gather(
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
out_ref = causal_conv1d_update_ref(
|
||||
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
|
||||
@@ -317,13 +320,19 @@ def test_causal_conv1d_varlen(
|
||||
has_initial_states = torch.randint(
|
||||
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
|
||||
)
|
||||
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
|
||||
:batch_size
|
||||
]
|
||||
# +1 to exclude index 0 (null block)
|
||||
state_indices = (
|
||||
torch.randperm(total_entries - 1, dtype=torch.int32, device=x.device)[
|
||||
:batch_size
|
||||
]
|
||||
+ 1
|
||||
)
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
torch.as_tensor(
|
||||
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
@@ -336,7 +345,6 @@ def test_causal_conv1d_varlen(
|
||||
cache_indices=padded_state_indices,
|
||||
has_initial_state=has_initial_states,
|
||||
activation=activation,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
|
||||
out_ref = []
|
||||
@@ -345,7 +353,7 @@ def test_causal_conv1d_varlen(
|
||||
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
|
||||
for i in range(len(seqlens[0])):
|
||||
x_s = [v[i].unsqueeze(0) for v in splits][0]
|
||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
||||
if padded_state_indices[i] == NULL_BLOCK_ID:
|
||||
continue
|
||||
out_ref_b.append(
|
||||
causal_conv1d_ref(
|
||||
|
||||
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
|
||||
|
||||
|
||||
def selective_state_update_ref(
|
||||
@@ -179,7 +179,7 @@ def selective_scan_opcheck_fn(
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
ssm_states=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
null_block_id=NULL_BLOCK_ID,
|
||||
block_size=2048,
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
@@ -229,7 +229,7 @@ def selective_scan_opcheck_fn(
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
ssm_states,
|
||||
pad_slot_id,
|
||||
null_block_id,
|
||||
block_size,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
@@ -351,7 +351,6 @@ def test_selective_scan(
|
||||
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
|
||||
if c > 0
|
||||
else None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
block_size=2048,
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
@@ -652,15 +651,21 @@ def test_selective_scan_varlen(
|
||||
prev_state_shape, device=u.device, dtype=itype, requires_grad=False
|
||||
)
|
||||
prev_state_ref = prev_state.clone()
|
||||
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[
|
||||
:batch_size
|
||||
]
|
||||
# +1 to exclude index 0 (null block)
|
||||
state_indices = (
|
||||
torch.randperm(total_entries - 1, dtype=torch.int32, device=u.device)[
|
||||
:batch_size
|
||||
]
|
||||
+ 1
|
||||
)
|
||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||
unused_states_bool[state_indices] = False
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
torch.as_tensor(
|
||||
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
@@ -690,7 +695,7 @@ def test_selective_scan_varlen(
|
||||
]
|
||||
for i in range(len(seqlens[0])):
|
||||
u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
|
||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
||||
if padded_state_indices[i] == NULL_BLOCK_ID:
|
||||
continue
|
||||
out_ref_s, _ = selective_scan_ref(
|
||||
u_s,
|
||||
@@ -758,7 +763,8 @@ def test_selective_state_update_with_batch_indices(
|
||||
padded_batch_size = batch_size + padding
|
||||
total_entries = 10 * batch_size
|
||||
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
|
||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
# +1 to exclude index 0 (null block)
|
||||
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
|
||||
dtype=torch.int32, device=device
|
||||
)
|
||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||
@@ -766,7 +772,9 @@ def test_selective_state_update_with_batch_indices(
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
torch.as_tensor(
|
||||
[NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device
|
||||
),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
@@ -793,7 +801,6 @@ def test_selective_state_update_with_batch_indices(
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=out,
|
||||
)
|
||||
out_ref = selective_state_update_ref(
|
||||
@@ -849,7 +856,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
||||
state = torch.randn(
|
||||
total_entries, nheads, headdim, dstate, dtype=itype, device=device
|
||||
)
|
||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
# +1 to exclude index 0 (null block)
|
||||
state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to(
|
||||
dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
@@ -887,7 +895,6 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=out,
|
||||
)
|
||||
out_ref = selective_state_update_ref(
|
||||
@@ -935,17 +942,18 @@ def test_selective_state_update_with_num_accepted_tokens(
|
||||
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
|
||||
|
||||
state_batch_indices = torch.full(
|
||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
||||
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||
)
|
||||
# Start from 1 to exclude null block at index 0
|
||||
initial_state_slots = torch.randint(
|
||||
0, 15, (batch_size,), device=device, dtype=torch.int32
|
||||
1, 15, (batch_size,), device=device, dtype=torch.int32
|
||||
)
|
||||
for seq_idx in range(batch_size):
|
||||
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
|
||||
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
|
||||
|
||||
dst_state_batch_indices = torch.full(
|
||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
||||
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||
)
|
||||
slot_offset = 15
|
||||
dst_slots_map = {}
|
||||
@@ -1013,7 +1021,6 @@ def test_selective_state_update_with_num_accepted_tokens(
|
||||
state_batch_indices=state_batch_indices,
|
||||
dst_state_batch_indices=dst_state_batch_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
@@ -1061,18 +1068,19 @@ def test_selective_state_update_varlen_with_num_accepted(
|
||||
state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
|
||||
|
||||
state_batch_indices = torch.full(
|
||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
||||
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Start from 1 to exclude null block at index 0
|
||||
initial_state_slots = torch.randint(
|
||||
0, 15, (batch_size,), device=device, dtype=torch.int32
|
||||
1, 15, (batch_size,), device=device, dtype=torch.int32
|
||||
)
|
||||
for seq_idx in range(batch_size):
|
||||
token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
|
||||
state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
|
||||
|
||||
dst_state_batch_indices = torch.full(
|
||||
(batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
|
||||
(batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
slot_offset = 15
|
||||
@@ -1138,7 +1146,6 @@ def test_selective_state_update_varlen_with_num_accepted(
|
||||
state_batch_indices=state_batch_indices,
|
||||
dst_state_batch_indices=dst_state_batch_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
|
||||
for seq_idx in range(batch_size):
|
||||
|
||||
Reference in New Issue
Block a user