[Kernel][Mamba] Optimize Mamba2 SSD prefill Triton kernels (#35397)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
This commit is contained in:
tomeras91
2026-03-04 20:47:17 +02:00
committed by GitHub
parent bc6be89d16
commit 7faba503c4
6 changed files with 155 additions and 369 deletions

View File

@@ -8,6 +8,7 @@ import torch
from packaging import version
from vllm import _custom_ops as ops
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@@ -215,7 +216,7 @@ def _selective_scan_update_kernel(
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
dA = tl.exp(A * dt[:, None])
dA = fast_exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
@@ -223,7 +224,7 @@ def _selective_scan_update_kernel(
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix
dA = fast_exp(A * dt) # scalar, not a matrix
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)

View File

@@ -8,6 +8,7 @@
from packaging import version
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
from vllm.triton_utils import tl, triton
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
@@ -15,6 +16,76 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
@triton.autotune(
configs=[
# =================================================================
# Higher warp count configs for better latency hiding
# More warps = more instructions in flight = better memory latency hiding
# =================================================================
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
# Smaller tiles with more stages for software pipelining
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
num_stages=2,
num_warps=4,
),
# =================================================================
# Low register pressure configs (num_stages=1) for large dstate
# =================================================================
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=1,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=1,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=1,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=1,
num_warps=4,
),
# num_stages=2 configs - moderate register pressure
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=4,
),
# Original configs for larger dstate values
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
@@ -200,7 +271,7 @@ def _chunk_scan_fwd_kernel(
offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
)
scale_m = tl.exp(dA_cs_m)
scale_m = fast_exp(dA_cs_m)
if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(
C_ptrs,
@@ -285,7 +356,7 @@ def _chunk_scan_fwd_kernel(
)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
cb *= fast_exp(dA_cs_m[:, None] - dA_cs_k[None, :])
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:

View File

@@ -8,6 +8,7 @@
import torch
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
from vllm.triton_utils import tl, triton
from .mamba_ssm import softplus
@@ -116,6 +117,34 @@ def _chunk_cumsum_fwd_kernel(
@triton.autotune(
configs=[
# Small headdim/dstate configs (hdim<=64, dstate<=128) - increased parallelism
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=4,
),
# Low register pressure configs for large dstate (dstate=128)
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=2,
num_warps=4,
),
# original configs for larger headdim/dstate values
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
@@ -251,7 +280,7 @@ def _chunk_state_fwd_kernel(
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
tl.float32
)
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
scale = fast_exp(dA_cs_last - dA_cs_k) * dt_k
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
@@ -273,238 +302,6 @@ def _chunk_state_fwd_kernel(
tl.store(states_ptrs, states, mask=c_mask)
# Triton kernel that produces, for every (sequence, head) pair, one
# (hdim, dstate) state tile evaluated at the last token of that sequence.
# The candidate tile shapes / pipeline depths below are autotuned, with the
# tuned choice keyed on (hdim, dstate, chunk_size) -- see `key`.
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2,
),
],
key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_varlen_kernel(
# Pointers to matrices
x_ptr,
b_ptr,
dt_ptr,
dA_cumsum_ptr,
chunk_states_ptr,
cu_seqlens_ptr,
states_ptr,
initstates_ptr,
# Matrix dimensions
hdim: tl.constexpr,
dstate: tl.constexpr,
chunk_size: tl.constexpr,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_b_dstate: tl.constexpr,
stride_dt_head: tl.int64,
stride_dt_chunk: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_chunk_states_chunk: tl.int64,
stride_chunk_states_head: tl.int64,
stride_chunk_states_hdim: tl.int64,
stride_chunk_states_dstate: tl.constexpr,
stride_states_batch: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_init_states_batch: tl.int64,
stride_init_states_head: tl.int64,
stride_init_states_hdim: tl.int64,
stride_init_states_dstate: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
HAS_INITSTATES: tl.constexpr,
):
# Program ids: axis 1 selects the sequence (it indexes cu_seqlens), axis 2
# selects the head, and axis 0 enumerates the (hdim tile, dstate tile) pairs.
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
# cu_seqlens[pid_b + 1] is used as the exclusive end position of this
# sequence, so (end_idx - 1) // chunk_size is the chunk holding its last token.
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
pid_c = (end_idx - 1) // chunk_size
# Advance every base pointer to that chunk and to this head. B is indexed
# by pid_h // nheads_ngroups_ratio, i.e. several heads share one B slice.
b_ptr += (
pid_c * chunk_size * stride_b_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
)
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
chunk_states_ptr += (
pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
)
if HAS_INITSTATES:
# if there are init states provided, we differentiate between states (which
# are boundary conditions at a chunk boundary) and initstates (which are boundary
# conditions when a new example in a cont batch starts)
initstates_ptr += pid_h * stride_init_states_head
# Tile offsets: m runs over hdim, n over dstate, k over positions in the chunk.
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
# x tile is addressed as (m, k) and b tile as (k, n) so tl.dot(x, b) yields (m, n).
x_ptrs = x_ptr + (
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
)
b_ptrs = b_ptr + (
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
# dA_cumsum value at the sequence's last token inside this chunk.
dA_cs_last = tl.load(
dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
).to(tl.float32)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
# Only positions in [start_idx_cur, chunk_size_limit) of this chunk are
# accumulated: chunk_size_limit is the in-chunk offset just past the last
# token, and start_idx_cur is the in-chunk offset of the sequence start
# (clamped to 0 when the sequence began before this chunk).
chunk_size_limit = end_idx - pid_c * chunk_size
start_idx = tl.load(cu_seqlens_ptr + pid_b)
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Walk the chunk in BLOCK_SIZE_K steps, accumulating x @ (b * scale).
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(
x_ptrs,
mask=(offs_m[:, None] < hdim)
& (offs_k[None, :] < chunk_size_limit - k)
& (offs_k[None, :] >= start_idx_cur - k),
other=0.0,
)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k)
& (offs_n[None, :] < dstate)
& (offs_k[:, None] >= start_idx_cur - k),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
tl.float32
)
# Per-position weight exp(dA_cs_last - dA_cs_k) * dt_k, forced to 0 for
# positions outside this sequence's range within the chunk.
scale = tl.where(
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
tl.exp(dA_cs_last - dA_cs_k) * dt_k,
0.0,
)
b *= scale[:, None]
# Cast b back to x's element type before the dot product.
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
# If the sequence starts inside or after the last chunk and no initial states
# are given, there is no carried-in state to add.
# If HAS_INITSTATES == True there are two possibilities:
# - if start_idx < pid_c * chunk_size, take the carried-in state from chunk_states
# - if start_idx >= pid_c * chunk_size, take the state from initstates instead
if (
(start_idx < pid_c * chunk_size) # first chunk
or (HAS_INITSTATES)
):
dA_cs_boundary = 0.0 # default
if not HAS_INITSTATES:
past_states_ptrs = chunk_states_ptr + (
offs_m[:, None] * stride_chunk_states_hdim
+ offs_n[None, :] * stride_chunk_states_dstate
)
else:
# - this seems repetitive, but it's to help the compiler
if start_idx < pid_c * chunk_size:
past_states_ptrs = chunk_states_ptr + (
offs_m[:, None] * stride_chunk_states_hdim
+ offs_n[None, :] * stride_chunk_states_dstate
)
else:
past_states_ptrs = initstates_ptr + (
pid_b * stride_init_states_batch
+ offs_m[:, None] * stride_init_states_hdim
+ offs_n[None, :] * stride_init_states_dstate
)
# need to adjust the boundary: when the sequence starts strictly inside
# this chunk, use the dA_cumsum value just before its first position
if start_idx > pid_c * chunk_size:
dA_cs_boundary = tl.load(
dA_cumsum_ptr
+ (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
).to(tl.float32)
past_states = tl.load(
past_states_ptrs,
mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
# Scale the carried-in state from the boundary up to the last token and add it.
scale = tl.exp(dA_cs_last - dA_cs_boundary)
acc += past_states * scale
# Write the (hdim, dstate) tile for this (sequence, head) in the output dtype.
states = acc.to(states_ptr.dtype.element_ty)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
)
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
tl.store(states_ptrs, states, mask=c_mask)
def _chunk_cumsum_fwd(
dt,
A,
@@ -612,89 +409,3 @@ def _chunk_state_fwd(
stride_dA_cs_csize=dA_cumsum.stride(2),
)
return states
def chunk_state_varlen(
    B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None
):
    """Launch ``_chunk_state_varlen_kernel`` and return per-sequence states.

    Shapes (enforced by the assertions below):
        x:              (total_seqlen, nheads, headdim)
        B:              (total_seqlen, ngroups, dstate)
        dt, dA_cumsum:  (nheads, nchunks, chunk_size)
        chunk_states:   (nchunks, nheads, headdim, dstate)
        cu_seqlens:     1-D with ``batch + 1`` entries
        initial_states: optional, (batch, nheads, headdim, dstate)

    Returns:
        A freshly allocated tensor of shape (batch, nheads, headdim, dstate)
        with the dtype and device of ``chunk_states``.
    """
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    with_init = initial_states is not None

    # Shape contract between the inputs.
    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
    if with_init:
        assert initial_states.shape == (batch, nheads, headdim, dstate)

    out_states = torch.empty(
        (batch, nheads, headdim, dstate),
        dtype=chunk_states.dtype,
        device=chunk_states.device,
    )

    # Without initial states the kernel still takes stride arguments; pass 0s.
    if with_init:
        init_strides = tuple(initial_states.stride(axis) for axis in range(4))
    else:
        init_strides = (0, 0, 0, 0)

    def grid(META):
        # axis 0: (headdim tiles x dstate tiles), axis 1: sequences, axis 2: heads
        tiles_m = triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(dstate, META["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n, batch, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[grid](
            x_ptr=x,
            b_ptr=B,
            dt_ptr=dt,
            dA_cumsum_ptr=dA_cumsum,
            chunk_states_ptr=chunk_states,
            cu_seqlens_ptr=cu_seqlens,
            states_ptr=out_states,
            initstates_ptr=initial_states,
            hdim=headdim,
            dstate=dstate,
            chunk_size=chunk_size,
            nheads_ngroups_ratio=nheads // ngroups,
            stride_x_seqlen=x.stride(0),
            stride_x_head=x.stride(1),
            stride_x_hdim=x.stride(2),
            stride_b_seqlen=B.stride(0),
            stride_b_head=B.stride(1),
            stride_b_dstate=B.stride(2),
            stride_dt_head=dt.stride(0),
            stride_dt_chunk=dt.stride(1),
            stride_dt_csize=dt.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            stride_chunk_states_chunk=chunk_states.stride(0),
            stride_chunk_states_head=chunk_states.stride(1),
            stride_chunk_states_hdim=chunk_states.stride(2),
            stride_chunk_states_dstate=chunk_states.stride(3),
            stride_states_batch=out_states.stride(0),
            stride_states_head=out_states.stride(1),
            stride_states_hdim=out_states.stride(2),
            stride_states_dstate=out_states.stride(3),
            stride_init_states_batch=init_strides[0],
            stride_init_states_head=init_strides[1],
            stride_init_states_hdim=init_strides[2],
            stride_init_states_dstate=init_strides[3],
            HAS_INITSTATES=with_init,
        )
    return out_states

View File

@@ -107,18 +107,15 @@ def _mamba_chunk_scan_combined_fwd(
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states and
# ii) seq_idx to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - parallelized across sequences using last_chunk_indices to derive
# per-sequence chunk ranges. Each sequence's state passing runs independently.
states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum, # (nheads, nchunks, chunk_size)
cu_chunk_seqlens,
last_chunk_indices,
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None
else None, # (batch, nheads, headdim*dstate)
seq_idx=seq_idx,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
)
states = rearrange(states, "... (p n) -> ... p n", n=dstate)

View File

@@ -8,6 +8,7 @@
import torch
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
from vllm.triton_utils import tl, triton
@@ -29,12 +30,9 @@ def _state_passing_fwd_kernel(
out_ptr,
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
cu_chunk_seqlens_ptr,
last_chunk_indices_ptr,
# Matrix dimensions
dim: tl.constexpr,
nchunks,
seqlen,
chunk_size: tl.constexpr,
# Strides
stride_states_chunk: tl.int64,
@@ -49,55 +47,51 @@ def _state_passing_fwd_kernel(
stride_initstates_batch: tl.int64,
stride_initstates_head: tl.int64,
stride_initstates_dim: tl.constexpr,
stride_seq_idx_chunk: tl.constexpr,
# Meta-parameters
HAS_INITSTATES: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_h = tl.program_id(axis=1)
pid_m = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
states_ptr += pid_h * stride_states_head
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize
out_ptr += pid_h * stride_out_head
# Derive this sequence's chunk range from last_chunk_indices
chunk_end = tl.load(last_chunk_indices_ptr + pid_b) + 1
chunk_start = (
tl.load(last_chunk_indices_ptr + pid_b - 1, mask=pid_b > 0, other=-1) + 1
)
# Offset pointers to this sequence's first chunk
states_ptr += chunk_start * stride_states_chunk + pid_h * stride_states_head
dA_cs_ptr += (
pid_h * stride_dA_cs_head
+ chunk_start * stride_dA_cs_chunk
+ (chunk_size - 1) * stride_dA_cs_csize
)
out_ptr += chunk_start * stride_out_chunk + pid_h * stride_out_head
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
states_ptrs = states_ptr + offs_m * stride_states_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
# Load initial state once — no per-chunk branching needed
if HAS_INITSTATES:
initstates_ptrs = (
initstates_ptr
+ pid_b * stride_initstates_batch
+ pid_h * stride_initstates_head
+ offs_m * stride_initstates_dim
)
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
else:
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
prev_seq_idx = 0
for c in range(nchunks):
# Loop over only this sequence's chunks — branchless
nchunks_this_seq = chunk_end - chunk_start
for _ in range(nchunks_this_seq):
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
# we have started a new sequence
if prev_seq_idx != seq_idx:
if HAS_INITSTATES:
initstates_ptrs = (
initstates_ptr
+ seq_idx * stride_initstates_batch
+ pid_h * stride_initstates_head
+ offs_m * stride_initstates_dim
)
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
tl.float32
)
else:
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
prev_seq_idx = seq_idx
states = tl.exp(dA_cs) * states + new_states
states = fast_exp(dA_cs) * states + new_states
tl.store(out_ptrs, states, mask=offs_m < dim)
states_ptrs += stride_states_chunk
@@ -108,15 +102,14 @@ def _state_passing_fwd_kernel(
def _state_passing_fwd(
states,
dA_cumsum,
cu_chunk_seqlens,
seq_idx,
last_chunk_indices,
initial_states=None,
out_dtype=None,
):
nchunks, nheads, dim = states.shape
chunk_size = dA_cumsum.shape[-1]
batch = last_chunk_indices.shape[0]
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
seqlen = seq_idx.shape[-1]
out_dtype = states.dtype if out_dtype is None else out_dtype
out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype)
@@ -126,19 +119,16 @@ def _state_passing_fwd(
else (0, 0, 0)
)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)
with torch.cuda.device(states.device.index):
_state_passing_fwd_kernel[grid](
states_ptr=states,
out_ptr=out,
dA_cs_ptr=dA_cumsum,
initstates_ptr=initial_states,
seq_idx_ptr=seq_idx,
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
last_chunk_indices_ptr=last_chunk_indices,
dim=dim,
nchunks=nchunks,
seqlen=seqlen if seq_idx is not None else 0,
chunk_size=chunk_size if seq_idx is not None else 0,
chunk_size=chunk_size,
stride_states_chunk=states.stride(0),
stride_states_head=states.stride(1),
stride_states_dim=states.stride(2),
@@ -151,7 +141,6 @@ def _state_passing_fwd(
stride_initstates_batch=initial_states_strides[0],
stride_initstates_head=initial_states_strides[1],
stride_initstates_dim=initial_states_strides[2],
stride_seq_idx_chunk=seq_idx.stride(0),
HAS_INITSTATES=initial_states is not None,
)
return out

View File

@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.triton_utils import tl, triton
@triton.jit
def fast_exp(x):
    """Compute exp(x) as exp2(x * log2(e)) via ``tl.math.exp2``.

    Intended as a drop-in replacement for ``tl.exp`` inside Triton kernels.
    The authors' stated rationale is that ``tl.math.exp2`` lowers to a single
    hardware exp2 instruction, whereas ``tl.exp`` takes a slower library path.
    NOTE(review): the actual lowering depends on the Triton version, backend
    and input dtype -- confirm on the target toolchain before relying on it.
    """
    # log2(e) == 1 / ln(2); rescales the exponent from base e to base 2.
    INV_LN2 = tl.constexpr(1.4426950408889634)
    return tl.math.exp2(x * INV_LN2)