Add SpecDec support to selective_state_update (#29488)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2025-12-08 23:45:18 +02:00
committed by GitHub
parent 799804d140
commit ae0f69b16a
2 changed files with 505 additions and 72 deletions

View File

@@ -36,10 +36,14 @@ else:
is not None
}
)
@triton.heuristics(
{"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens_ptr"] is not None}
)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens_ptr"] is not None})
@triton.heuristics(
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit
@triton.jit(do_not_specialize=["N"])
def _selective_scan_update_kernel(
# Pointers to matrices
state_ptr,
@@ -55,8 +59,10 @@ def _selective_scan_update_kernel(
state_batch_indices_ptr,
dst_state_batch_indices_ptr,
pad_slot_id,
num_accepted_tokens_ptr,
cu_seqlens_ptr,
# Matrix dimensions
batch,
N,
nheads,
dim,
dstate,
@@ -91,6 +97,10 @@ def _selective_scan_update_kernel(
stride_out_batch,
stride_out_head,
stride_out_dim,
stride_state_indices_batch,
stride_state_indices_T,
stride_dst_state_indices_batch,
stride_dst_state_indices_T,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
TIE_HDIM: tl.constexpr,
@@ -99,22 +109,50 @@ def _selective_scan_update_kernel(
HAS_D: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_STATE_BATCH_INDICES: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
IS_VARLEN: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
if IS_VARLEN:
bos = tl.load(cu_seqlens_ptr + pid_b).to(tl.int64)
eos = tl.load(cu_seqlens_ptr + pid_b + 1).to(tl.int64)
seq_len = eos - bos
if seq_len == 0:
return
else:
bos = pid_b
seq_len = 1
state_ptr_base = state_ptr
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
# is the same as the batch id.
if HAS_STATE_BATCH_INDICES:
dst_state_batch_indices_ptr += pid_b
dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
dst_state_ptr = state_ptr + (
dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
if IS_SPEC_DECODING:
num_accepted = tl.load(num_accepted_tokens_ptr + pid_b).to(tl.int64)
init_token_idx = tl.maximum(num_accepted - 1, 0)
else:
init_token_idx = 0
dst_state_batch_indices_ptr += pid_b * stride_dst_state_indices_batch
if not IS_SPEC_DECODING:
dst_state_batch_idx = tl.load(
dst_state_batch_indices_ptr
+ init_token_idx * stride_dst_state_indices_T
).to(tl.int64)
dst_state_ptr = state_ptr + (
dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
)
state_batch_indices_ptr += (
pid_b * stride_state_indices_batch + init_token_idx * stride_state_indices_T
)
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
else:
@@ -123,86 +161,112 @@ def _selective_scan_update_kernel(
)
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
x_ptr += bos * stride_x_batch + pid_h * stride_x_head
dt_ptr += bos * stride_dt_batch + pid_h * stride_dt_head
if HAS_DT_BIAS:
dt_bias_ptr += pid_h * stride_dt_bias_head
A_ptr += pid_h * stride_A_head
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
B_ptr += bos * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
C_ptr += bos * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
if HAS_Z:
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
z_ptr += bos * stride_z_batch + pid_h * stride_z_head
out_ptr += bos * stride_out_batch + pid_h * stride_out_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
dst_state_ptrs = dst_state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
if not IS_SPEC_DECODING:
dst_state_ptrs = dst_state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id
state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
if HAS_D:
D_ptr += pid_h * stride_D_head
A_ptrs = A_ptr + (
offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
)
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_D:
D_ptrs = D_ptr + offs_m * stride_D_dim
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id
state = tl.load(state_ptrs, mask=mask, other=0.0)
A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(
A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix
for i_t in range(seq_len):
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
if HAS_D:
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_Z:
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(
A_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
if HAS_D:
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_Z:
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id
tl.store(dst_state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
if HAS_Z:
out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
if IS_SPEC_DECODING:
dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T
token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64)
if token_dst_idx != pad_slot_id:
token_dst_ptrs = (
state_ptr_base
+ token_dst_idx * stride_state_batch
+ pid_h * stride_state_head
+ offs_m[:, None] * stride_state_dim
+ offs_n[None, :] * stride_state_dstate
)
tl.store(
token_dst_ptrs, state.to(token_dst_ptrs.dtype.element_ty), mask=mask
)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
if HAS_Z:
out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
x_ptr += stride_x_batch
dt_ptr += stride_dt_batch
B_ptr += stride_B_batch
C_ptr += stride_C_batch
out_ptr += stride_out_batch
if HAS_Z:
z_ptr += stride_z_batch
if not IS_SPEC_DECODING:
tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask)
def selective_state_update(
@@ -220,6 +284,8 @@ def selective_state_update(
dst_state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID,
out=None,
num_accepted_tokens=None,
cu_seqlens=None,
):
"""
Argument:
@@ -240,6 +306,11 @@ def selective_state_update(
indices 0 and 3
out: Preallocated ssm output tensor. Assume same shape as x.
In-place updated.
num_accepted_tokens: (batch,)
number of accepted tokens from previous verification step,
tells the kernel which initial state to use
cu_seqlens: (batch,)
length per sequence, for variable length in speculative decoding cases
"""
if state.dim() == 3:
state = state.unsqueeze(1)
@@ -261,9 +332,26 @@ def selective_state_update(
dt_bias = dt_bias.unsqueeze(0)
if out.dim() == 2:
out = out.unsqueeze(1)
if num_accepted_tokens is not None:
assert state_batch_indices is not None and state_batch_indices.dim() == 2
assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
if state_batch_indices is not None and state_batch_indices.dim() == 1:
state_batch_indices = state_batch_indices.unsqueeze(1)
if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1:
dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1)
_, nheads, dim, dstate = state.shape
batch = x.shape[0]
if cu_seqlens is not None:
N = len(cu_seqlens) - 1
# Only used to verify the shape of
# state_batch_indices and dst_state_batch_indices
max_seqlen = (
state_batch_indices.size(-1) if state_batch_indices is not None else 1
)
else:
N = batch
max_seqlen = 1
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
@@ -279,16 +367,30 @@ def selective_state_update(
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape == (batch,)
assert state_batch_indices.shape[0] >= N
assert state_batch_indices.shape[1] >= max_seqlen
if dst_state_batch_indices is not None:
assert dst_state_batch_indices.shape == (batch,)
assert dst_state_batch_indices.shape[0] >= N
assert dst_state_batch_indices.shape[1] >= max_seqlen
else:
# revert to the default behavior of in-place state updates
dst_state_batch_indices = state_batch_indices
assert out.shape == x.shape
if num_accepted_tokens is not None:
assert num_accepted_tokens.shape == (N,)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), N, nheads)
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
state_batch_indices_strides = (
(state_batch_indices.stride(0), state_batch_indices.stride(1))
if state_batch_indices is not None
else (0, 0)
)
dst_state_batch_indices_strides = (
(dst_state_batch_indices.stride(0), dst_state_batch_indices.stride(1))
if dst_state_batch_indices is not None
else (0, 0)
)
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = (
@@ -321,7 +423,9 @@ def selective_state_update(
state_batch_indices,
dst_state_batch_indices,
pad_slot_id,
batch,
num_accepted_tokens,
cu_seqlens,
N,
nheads,
dim,
dstate,
@@ -353,6 +457,10 @@ def selective_state_update(
out.stride(0),
out.stride(1),
out.stride(2),
state_batch_indices_strides[0],
state_batch_indices_strides[1],
dst_state_batch_indices_strides[0],
dst_state_batch_indices_strides[1],
dt_softplus,
tie_hdim,
BLOCK_SIZE_M,