Add SpecDec support to selective_state_update (#29488)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2025-12-08 23:45:18 +02:00
committed by GitHub
parent 799804d140
commit ae0f69b16a
2 changed files with 505 additions and 72 deletions

View File

@@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("max_seq_len", [1, 2, 4])
def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
    """Compare varlen (cu_seqlens) selective_state_update against a
    token-by-token run of the single-step reference implementation."""
    device = "cuda"
    # Tolerances widen for low-precision dtypes; ROCm bf16 gets extra atol.
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 5e-2, 1.5e-1
        if torch.version.hip:
            atol *= 2
    # Deterministic inputs.
    current_platform.seed_everything(0)
    batch_size = 4
    # Random token count per sequence in [1, max_seq_len].
    seq_token_counts = torch.randint(
        1, max_seq_len + 1, (batch_size,), device=device
    )
    total_tokens = int(seq_token_counts.sum().item())
    cu_seqlens = torch.tensor(
        [0] + torch.cumsum(seq_token_counts, dim=0).tolist(),
        dtype=torch.int32,
        device=device,
    )
    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    # Snapshot the state so the reference can evolve its own copy.
    state_ref = state.detach().clone()
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
        cu_seqlens=cu_seqlens,
    )
    # Reference: apply each sequence's tokens one at a time, in order.
    # The reference mutates state_ref[b] in place across tokens.
    ref_chunks = []
    for b in range(batch_size):
        lo, hi = int(cu_seqlens[b].item()), int(cu_seqlens[b + 1].item())
        for idx in range(lo, hi):
            ref_chunks.append(
                selective_state_update_ref(
                    state_ref[b : b + 1],
                    x[idx : idx + 1],
                    dt[idx : idx + 1],
                    A,
                    B[idx : idx + 1],
                    C[idx : idx + 1],
                    D=D,
                    z=None if z is None else z[idx : idx + 1],
                    dt_bias=dt_bias,
                    dt_softplus=True,
                )
            )
    out_ref = torch.cat(ref_chunks, dim=0)
    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize("itype", [torch.float32])
@pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
@@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices(
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.parametrize("max_seq_len", [2, 4])
def test_selective_state_update_with_num_accepted_tokens(
    dim, dstate, has_z, itype, max_seq_len
):
    """Speculative-decoding path of ``selective_state_update``.

    Each sequence's initial SSM state is read from the slot recorded at the
    position of its last accepted token (``state_batch_indices``), and the
    kernel is expected to write one updated state per token into the slots
    given by ``dst_state_batch_indices``.  Both the outputs and every
    per-token intermediate state are compared against a sequential
    token-by-token reference run.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 5e-2, 1.5e-1
        # ROCm bf16 needs extra absolute tolerance.
        if torch.version.hip:
            atol *= 2
    current_platform.seed_everything(0)
    batch_size = 4
    # Variable number of tokens per sequence in [1, max_seq_len].
    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
    total_tokens = int(tokens_per_seq.sum().item())
    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
    cu_seqlens = torch.tensor(
        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
        dtype=torch.int32,
        device=device,
    )
    # State pool is oversized; source slots come from [0, 15) and
    # destination slots start at 15, so the two ranges never overlap.
    total_state_slots = 50
    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
    # Source slot table: PAD everywhere except at the position of the last
    # accepted token of each sequence (position 0 when none were accepted).
    state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )
    initial_state_slots = torch.randint(
        0, 15, (batch_size,), device=device, dtype=torch.int32
    )
    for seq_idx in range(batch_size):
        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
    # Destination slot table: a unique slot (>= 15) per (sequence, token);
    # dst_slots_map keeps the same mapping on the host for the final check.
    dst_state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )
    slot_offset = 15
    dst_slots_map = {}
    for seq_idx in range(batch_size):
        for token_idx in range(tokens_per_seq[seq_idx].item()):
            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
            dst_slots_map[(seq_idx, token_idx)] = slot_offset
            slot_offset += 1
    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    # Reference: evolve each sequence token-by-token starting from its
    # initial slot, recording the state after every token.  The reference
    # updates state_seq in place, so a clone is taken per token.
    state_ref_intermediate = {}
    out_ref_list = []
    for seq_idx in range(batch_size):
        seq_start = cu_seqlens[seq_idx].item()
        seq_end = cu_seqlens[seq_idx + 1].item()
        num_tokens = seq_end - seq_start
        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
        initial_slot = state_batch_indices[seq_idx, token_pos].item()
        state_seq = state[initial_slot : initial_slot + 1].clone()
        for token_idx in range(num_tokens):
            global_idx = seq_start + token_idx
            out_token = selective_state_update_ref(
                state_seq,
                x[global_idx : global_idx + 1],
                dt[global_idx : global_idx + 1],
                A,
                B[global_idx : global_idx + 1],
                C[global_idx : global_idx + 1],
                D=D,
                z=z[global_idx : global_idx + 1] if has_z else None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            out_ref_list.append(out_token)
            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
    out_ref = torch.cat(out_ref_list, dim=0)
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
        cu_seqlens=cu_seqlens,
        state_batch_indices=state_batch_indices,
        dst_state_batch_indices=dst_state_batch_indices,
        num_accepted_tokens=num_accepted_tokens,
        pad_slot_id=PAD_SLOT_ID,
    )
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
    # Every per-token destination slot must hold the reference state as it
    # was right after that token was applied.
    for seq_idx in range(batch_size):
        num_tokens = tokens_per_seq[seq_idx].item()
        for token_idx in range(num_tokens):
            dst_slot = dst_slots_map[(seq_idx, token_idx)]
            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.parametrize("max_seq_len", [2, 4])
def test_selective_state_update_varlen_with_num_accepted(
    dim, dstate, has_z, itype, max_seq_len
):
    """Varlen speculative-decoding path: per-token destination states only.

    Same setup as the ``num_accepted_tokens`` test above — initial state read
    from the last accepted token's slot, one destination slot per
    (sequence, token) — but this test validates only the per-token states
    written to ``dst_state_batch_indices``.
    NOTE(review): ``out`` is produced by the kernel but never compared here;
    the reference outputs are discarded as well.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 5e-2, 1.5e-1
        # ROCm bf16 needs extra absolute tolerance.
        if torch.version.hip:
            atol *= 2
    current_platform.seed_everything(0)
    batch_size = 4
    # Variable number of tokens per sequence in [1, max_seq_len].
    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
    total_tokens = int(tokens_per_seq.sum().item())
    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
    cu_seqlens = torch.tensor(
        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
        dtype=torch.int32,
        device=device,
    )
    # Oversized state pool; source slots [0, 15) and destination slots
    # (>= 15) never overlap.
    total_state_slots = 50
    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
    # Source slot table: PAD everywhere except at the position of the last
    # accepted token of each sequence (position 0 when none were accepted).
    state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )
    initial_state_slots = torch.randint(
        0, 15, (batch_size,), device=device, dtype=torch.int32
    )
    for seq_idx in range(batch_size):
        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
    # Destination slot table: a unique slot per (sequence, token), mirrored
    # on the host in dst_slots_map for the final check.
    dst_state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )
    slot_offset = 15
    dst_slots_map = {}
    for seq_idx in range(batch_size):
        for token_idx in range(tokens_per_seq[seq_idx].item()):
            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
            dst_slots_map[(seq_idx, token_idx)] = slot_offset
            slot_offset += 1
    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    # Reference: evolve each sequence token-by-token from its initial slot;
    # the reference mutates state_seq in place, so clone after each token.
    state_ref_intermediate = {}
    for seq_idx in range(batch_size):
        seq_start = cu_seqlens[seq_idx].item()
        seq_end = cu_seqlens[seq_idx + 1].item()
        num_tokens = seq_end - seq_start
        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
        initial_slot = state_batch_indices[seq_idx, token_pos].item()
        state_seq = state[initial_slot : initial_slot + 1].clone()
        for token_idx in range(num_tokens):
            global_idx = seq_start + token_idx
            # Return value (the output token) is intentionally unused here.
            selective_state_update_ref(
                state_seq,
                x[global_idx : global_idx + 1],
                dt[global_idx : global_idx + 1],
                A,
                B[global_idx : global_idx + 1],
                C[global_idx : global_idx + 1],
                D=D,
                z=z[global_idx : global_idx + 1] if has_z else None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
        cu_seqlens=cu_seqlens,
        state_batch_indices=state_batch_indices,
        dst_state_batch_indices=dst_state_batch_indices,
        num_accepted_tokens=num_accepted_tokens,
        pad_slot_id=PAD_SLOT_ID,
    )
    # Every per-token destination slot must hold the reference state as it
    # was right after that token was applied.
    for seq_idx in range(batch_size):
        num_tokens = tokens_per_seq[seq_idx].item()
        for token_idx in range(num_tokens):
            dst_slot = dst_slots_map[(seq_idx, token_idx)]
            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)