[Kernel][Mamba] Optimize Mamba2 SSD prefill Triton kernels (#35397)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
|
||||
@@ -215,7 +216,7 @@ def _selective_scan_update_kernel(
|
||||
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
dA = fast_exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
@@ -223,7 +224,7 @@ def _selective_scan_update_kernel(
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptr).to(tl.float32)
|
||||
dA = tl.exp(A * dt) # scalar, not a matrix
|
||||
dA = fast_exp(A * dt) # scalar, not a matrix
|
||||
|
||||
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
from packaging import version
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
@@ -15,6 +16,76 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# =================================================================
|
||||
# Higher warp count configs for better latency hiding
|
||||
# More warps = more instructions in flight = better memory latency hiding
|
||||
# =================================================================
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
# Smaller tiles with more stages for software pipelining
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=3,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
# =================================================================
|
||||
# Low register pressure configs (num_stages=1) for large dstate
|
||||
# =================================================================
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
# num_stages=2 configs - moderate register pressure
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
# Original configs for larger dstate values
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
@@ -200,7 +271,7 @@ def _chunk_scan_fwd_kernel(
|
||||
offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
|
||||
)
|
||||
|
||||
scale_m = tl.exp(dA_cs_m)
|
||||
scale_m = fast_exp(dA_cs_m)
|
||||
if BLOCK_SIZE_DSTATE <= 128:
|
||||
C = tl.load(
|
||||
C_ptrs,
|
||||
@@ -285,7 +356,7 @@ def _chunk_scan_fwd_kernel(
|
||||
)
|
||||
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
|
||||
# So we don't need masking wrt seq_idx here.
|
||||
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
|
||||
cb *= fast_exp(dA_cs_m[:, None] - dA_cs_k[None, :])
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
|
||||
cb *= dt_k
|
||||
if IS_CAUSAL:
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .mamba_ssm import softplus
|
||||
@@ -116,6 +117,34 @@ def _chunk_cumsum_fwd_kernel(
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# Small headdim/dstate configs (hdim<=64, dstate<=128) - increased parallelism
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=3,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=3,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=3,
|
||||
num_warps=4,
|
||||
),
|
||||
# Low register pressure configs for large dstate (dstate=128)
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
# original configs for larger headdim/dstate values
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
@@ -251,7 +280,7 @@ def _chunk_state_fwd_kernel(
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
|
||||
scale = fast_exp(dA_cs_last - dA_cs_k) * dt_k
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
@@ -273,238 +302,6 @@ def _chunk_state_fwd_kernel(
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_varlen_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    chunk_states_ptr,
    cu_seqlens_ptr,
    states_ptr,
    initstates_ptr,
    # Matrix dimensions
    hdim: tl.constexpr,
    dstate: tl.constexpr,
    chunk_size: tl.constexpr,
    nheads_ngroups_ratio: tl.constexpr,
    # Strides
    stride_x_seqlen: tl.int64,
    stride_x_head: tl.int64,
    stride_x_hdim: tl.constexpr,
    stride_b_seqlen: tl.int64,
    stride_b_head: tl.int64,
    stride_b_dstate: tl.constexpr,
    stride_dt_head: tl.int64,
    stride_dt_chunk: tl.int64,
    stride_dt_csize: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    stride_chunk_states_chunk: tl.int64,
    stride_chunk_states_head: tl.int64,
    stride_chunk_states_hdim: tl.int64,
    stride_chunk_states_dstate: tl.constexpr,
    stride_states_batch: tl.int64,
    stride_states_head: tl.int64,
    stride_states_hdim: tl.int64,
    stride_states_dstate: tl.constexpr,
    stride_init_states_batch: tl.int64,
    stride_init_states_head: tl.int64,
    stride_init_states_hdim: tl.int64,
    stride_init_states_dstate: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    HAS_INITSTATES: tl.constexpr,
):
    """Write one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of a sequence's state.

    Program axis 0 enumerates the (hdim tile, dstate tile) pairs, axis 1
    selects the entry of ``cu_seqlens_ptr`` and axis 2 the head. The tile is
    accumulated from the chunk that holds the sequence's last token, then the
    carried-in state (previous chunk state or initial state) is added.
    """
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    tiles_along_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // tiles_along_n
    pid_n = tl.program_id(axis=0) % tiles_along_n

    # pid_c is the chunk that contains token end_idx - 1 of this sequence.
    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
    pid_c = (end_idx - 1) // chunk_size

    # Move every base pointer to that chunk and to this head (B is indexed
    # by group, i.e. head // nheads_ngroups_ratio).
    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    b_ptr += (
        pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    chunk_states_ptr += (
        pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
    )

    if HAS_INITSTATES:
        # With init states supplied, two kinds of boundary conditions exist:
        # chunk states (at a chunk boundary) and init states (where a new
        # example of the continuous batch begins).
        initstates_ptr += pid_h * stride_init_states_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Shared by the carried-in state load and the final store.
    tile_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)

    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
    )
    b_ptrs = b_ptr + (
        offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
    )
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    # Number of positions of chunk pid_c that lie before end_idx.
    chunk_size_limit = end_idx - pid_c * chunk_size
    # Cumulative dA at the sequence's last token.
    dA_cs_last = tl.load(
        dA_cumsum_ptr + (chunk_size_limit - 1) * stride_dA_cs_csize
    ).to(tl.float32)
    # First in-chunk position belonging to this sequence (0 when the
    # sequence began in an earlier chunk).
    start_idx = tl.load(cu_seqlens_ptr + pid_b)
    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        before_end = offs_k < chunk_size_limit - k
        # Positions in [start_idx_cur, chunk_size_limit) for this K step.
        in_seq = (offs_k >= start_idx_cur - k) & before_end

        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < hdim) & in_seq[None, :],
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=in_seq[:, None] & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=before_end, other=0.0).to(
            tl.float32
        )
        dt_k = tl.load(dt_ptrs, mask=before_end, other=0.0).to(tl.float32)

        # Zero the weight of positions outside the sequence explicitly.
        scale = tl.where(in_seq, tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
        b *= scale[:, None]
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)

        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    # Add the carried-in state when the sequence began before this chunk or
    # when init states were supplied. With HAS_INITSTATES there are two cases:
    # - start_idx <  pid_c * chunk_size: read the chunk states
    # - start_idx >= pid_c * chunk_size: read the init states
    if (
        (start_idx < pid_c * chunk_size)  # sequence began in an earlier chunk
        or (HAS_INITSTATES)
    ):
        dA_cs_boundary = 0.0  # default

        if not HAS_INITSTATES:
            past_states_ptrs = chunk_states_ptr + (
                offs_m[:, None] * stride_chunk_states_hdim
                + offs_n[None, :] * stride_chunk_states_dstate
            )
        else:
            # The duplicated branch looks redundant, but it is kept to help
            # the compiler.
            if start_idx < pid_c * chunk_size:
                past_states_ptrs = chunk_states_ptr + (
                    offs_m[:, None] * stride_chunk_states_hdim
                    + offs_n[None, :] * stride_chunk_states_dstate
                )
            else:
                past_states_ptrs = initstates_ptr + (
                    pid_b * stride_init_states_batch
                    + offs_m[:, None] * stride_init_states_hdim
                    + offs_n[None, :] * stride_init_states_dstate
                )

                # The sequence starts inside this chunk, so the decay is
                # measured from the position just before its first token.
                if start_idx > pid_c * chunk_size:
                    dA_cs_boundary = tl.load(
                        dA_cumsum_ptr
                        + (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
                    ).to(tl.float32)

        past_states = tl.load(past_states_ptrs, mask=tile_mask, other=0.0).to(
            tl.float32
        )

        boundary_decay = tl.exp(dA_cs_last - dA_cs_boundary)
        acc += past_states * boundary_decay

    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    states_ptrs = states_ptr + (
        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
    )
    tl.store(states_ptrs, states, mask=tile_mask)
|
||||
|
||||
|
||||
def _chunk_cumsum_fwd(
|
||||
dt,
|
||||
A,
|
||||
@@ -612,89 +409,3 @@ def _chunk_state_fwd(
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
)
|
||||
return states
|
||||
|
||||
|
||||
def chunk_state_varlen(
    B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None
):
    """Launch ``_chunk_state_varlen_kernel`` and return its output tensor.

    Shapes enforced by the assertions below:
        x:              (total_seqlen, nheads, headdim)
        B:              (total_seqlen, ngroups, dstate)
        dt, dA_cumsum:  (nheads, nchunks, chunk_size)
        chunk_states:   (nchunks, nheads, headdim, dstate)
        initial_states: (batch, nheads, headdim, dstate) or None

    ``batch`` is ``cu_seqlens.shape[0] - 1``.

    Returns:
        A tensor of shape (batch, nheads, headdim, dstate) with the dtype and
        device of ``chunk_states``.
    """
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()

    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)

    has_initial_states = initial_states is not None
    if has_initial_states:
        assert initial_states.shape == (batch, nheads, headdim, dstate)

    states = torch.empty(
        batch,
        nheads,
        headdim,
        dstate,
        dtype=chunk_states.dtype,
        device=chunk_states.device,
    )

    # The kernel always receives four init-state strides; zeros stand in
    # when no initial states are given.
    if has_initial_states:
        init_strides = tuple(initial_states.stride(d) for d in range(4))
    else:
        init_strides = (0, 0, 0, 0)

    def grid(META):
        # Axis 0 covers every (headdim tile, dstate tile) pair.
        tiles_m = triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(dstate, META["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n, batch, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[grid](
            x_ptr=x,
            b_ptr=B,
            dt_ptr=dt,
            dA_cumsum_ptr=dA_cumsum,
            chunk_states_ptr=chunk_states,
            cu_seqlens_ptr=cu_seqlens,
            states_ptr=states,
            initstates_ptr=initial_states,
            hdim=headdim,
            dstate=dstate,
            chunk_size=chunk_size,
            nheads_ngroups_ratio=nheads // ngroups,
            stride_x_seqlen=x.stride(0),
            stride_x_head=x.stride(1),
            stride_x_hdim=x.stride(2),
            stride_b_seqlen=B.stride(0),
            stride_b_head=B.stride(1),
            stride_b_dstate=B.stride(2),
            stride_dt_head=dt.stride(0),
            stride_dt_chunk=dt.stride(1),
            stride_dt_csize=dt.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            stride_chunk_states_chunk=chunk_states.stride(0),
            stride_chunk_states_head=chunk_states.stride(1),
            stride_chunk_states_hdim=chunk_states.stride(2),
            stride_chunk_states_dstate=chunk_states.stride(3),
            stride_states_batch=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_hdim=states.stride(2),
            stride_states_dstate=states.stride(3),
            stride_init_states_batch=init_strides[0],
            stride_init_states_head=init_strides[1],
            stride_init_states_hdim=init_strides[2],
            stride_init_states_dstate=init_strides[3],
            HAS_INITSTATES=has_initial_states,
        )
    return states
|
||||
|
||||
@@ -107,18 +107,15 @@ def _mamba_chunk_scan_combined_fwd(
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
|
||||
# (middle term of factorization of off-diag blocks; A terms)
|
||||
# - for handling chunked prefill, this requires i) initial_states and
|
||||
# ii) seq_idx to be all specified.
|
||||
# - When a new seq_idx is detected, we will stop passing the prev_state
|
||||
# and switch accordingly to the init_state corresponding to the new seq_idx.
|
||||
# - parallelized across sequences using last_chunk_indices to derive
|
||||
# per-sequence chunk ranges. Each sequence's state passing runs independently.
|
||||
states = _state_passing_fwd(
|
||||
rearrange(states, "... p n -> ... (p n)"),
|
||||
dA_cumsum, # (nheads, nchunks, chunk_size)
|
||||
cu_chunk_seqlens,
|
||||
last_chunk_indices,
|
||||
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
|
||||
if initial_states is not None
|
||||
else None, # (batch, nheads, headdim*dstate)
|
||||
seq_idx=seq_idx,
|
||||
out_dtype=state_dtype if state_dtype is not None else C.dtype,
|
||||
)
|
||||
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@@ -29,12 +30,9 @@ def _state_passing_fwd_kernel(
|
||||
out_ptr,
|
||||
dA_cs_ptr,
|
||||
initstates_ptr,
|
||||
seq_idx_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
last_chunk_indices_ptr,
|
||||
# Matrix dimensions
|
||||
dim: tl.constexpr,
|
||||
nchunks,
|
||||
seqlen,
|
||||
chunk_size: tl.constexpr,
|
||||
# Strides
|
||||
stride_states_chunk: tl.int64,
|
||||
@@ -49,55 +47,51 @@ def _state_passing_fwd_kernel(
|
||||
stride_initstates_batch: tl.int64,
|
||||
stride_initstates_head: tl.int64,
|
||||
stride_initstates_dim: tl.constexpr,
|
||||
stride_seq_idx_chunk: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid_h = tl.program_id(axis=1)
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
|
||||
states_ptr += pid_h * stride_states_head
|
||||
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize
|
||||
out_ptr += pid_h * stride_out_head
|
||||
# Derive this sequence's chunk range from last_chunk_indices
|
||||
chunk_end = tl.load(last_chunk_indices_ptr + pid_b) + 1
|
||||
chunk_start = (
|
||||
tl.load(last_chunk_indices_ptr + pid_b - 1, mask=pid_b > 0, other=-1) + 1
|
||||
)
|
||||
|
||||
# Offset pointers to this sequence's first chunk
|
||||
states_ptr += chunk_start * stride_states_chunk + pid_h * stride_states_head
|
||||
dA_cs_ptr += (
|
||||
pid_h * stride_dA_cs_head
|
||||
+ chunk_start * stride_dA_cs_chunk
|
||||
+ (chunk_size - 1) * stride_dA_cs_csize
|
||||
)
|
||||
out_ptr += chunk_start * stride_out_chunk + pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
states_ptrs = states_ptr + offs_m * stride_states_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
|
||||
# Load initial state once — no per-chunk branching needed
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptrs = (
|
||||
initstates_ptr
|
||||
+ pid_b * stride_initstates_batch
|
||||
+ pid_h * stride_initstates_head
|
||||
+ offs_m * stride_initstates_dim
|
||||
)
|
||||
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
else:
|
||||
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
prev_seq_idx = 0
|
||||
for c in range(nchunks):
|
||||
# Loop over only this sequence's chunks — branchless
|
||||
nchunks_this_seq = chunk_end - chunk_start
|
||||
for _ in range(nchunks_this_seq):
|
||||
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
||||
seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
|
||||
# we have started a new sequence
|
||||
if prev_seq_idx != seq_idx:
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptrs = (
|
||||
initstates_ptr
|
||||
+ seq_idx * stride_initstates_batch
|
||||
+ pid_h * stride_initstates_head
|
||||
+ offs_m * stride_initstates_dim
|
||||
)
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
else:
|
||||
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
prev_seq_idx = seq_idx
|
||||
states = tl.exp(dA_cs) * states + new_states
|
||||
states = fast_exp(dA_cs) * states + new_states
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
|
||||
states_ptrs += stride_states_chunk
|
||||
@@ -108,15 +102,14 @@ def _state_passing_fwd_kernel(
|
||||
def _state_passing_fwd(
|
||||
states,
|
||||
dA_cumsum,
|
||||
cu_chunk_seqlens,
|
||||
seq_idx,
|
||||
last_chunk_indices,
|
||||
initial_states=None,
|
||||
out_dtype=None,
|
||||
):
|
||||
nchunks, nheads, dim = states.shape
|
||||
chunk_size = dA_cumsum.shape[-1]
|
||||
batch = last_chunk_indices.shape[0]
|
||||
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
|
||||
seqlen = seq_idx.shape[-1]
|
||||
out_dtype = states.dtype if out_dtype is None else out_dtype
|
||||
out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype)
|
||||
|
||||
@@ -126,19 +119,16 @@ def _state_passing_fwd(
|
||||
else (0, 0, 0)
|
||||
)
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads)
|
||||
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)
|
||||
with torch.cuda.device(states.device.index):
|
||||
_state_passing_fwd_kernel[grid](
|
||||
states_ptr=states,
|
||||
out_ptr=out,
|
||||
dA_cs_ptr=dA_cumsum,
|
||||
initstates_ptr=initial_states,
|
||||
seq_idx_ptr=seq_idx,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
last_chunk_indices_ptr=last_chunk_indices,
|
||||
dim=dim,
|
||||
nchunks=nchunks,
|
||||
seqlen=seqlen if seq_idx is not None else 0,
|
||||
chunk_size=chunk_size if seq_idx is not None else 0,
|
||||
chunk_size=chunk_size,
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_dim=states.stride(2),
|
||||
@@ -151,7 +141,6 @@ def _state_passing_fwd(
|
||||
stride_initstates_batch=initial_states_strides[0],
|
||||
stride_initstates_head=initial_states_strides[1],
|
||||
stride_initstates_dim=initial_states_strides[2],
|
||||
stride_seq_idx_chunk=seq_idx.stride(0),
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return out
|
||||
|
||||
17
vllm/model_executor/layers/mamba/ops/triton_helpers.py
Normal file
17
vllm/model_executor/layers/mamba/ops/triton_helpers.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
def fast_exp(x):
    """Evaluate exp(x) by way of the base-2 exponential.

    Uses the identity exp(x) = exp2(x * log2(e)). tl.math.exp2 maps directly
    to a single ex2.approx.f32 PTX instruction, whereas tl.exp goes through
    libdevice __nv_expf, which adds function call overhead and extra range
    checking.
    """
    # log2(e) = 1 / ln(2)
    LOG2_E = tl.constexpr(1.4426950408889634)
    return tl.math.exp2(x * LOG2_E)
|
||||
Reference in New Issue
Block a user