diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 44e73dd20..50778a990 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -8,6 +8,7 @@ import torch from packaging import version from vllm import _custom_ops as ops +from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.v1.attention.backends.utils import PAD_SLOT_ID @@ -215,7 +216,7 @@ def _selective_scan_update_kernel( mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0, ).to(tl.float32) - dA = tl.exp(A * dt[:, None]) + dA = fast_exp(A * dt[:, None]) else: dt = tl.load(dt_ptr).to(tl.float32) if HAS_DT_BIAS: @@ -223,7 +224,7 @@ def _selective_scan_update_kernel( if DT_SOFTPLUS: dt = softplus(dt) A = tl.load(A_ptr).to(tl.float32) - dA = tl.exp(A * dt) # scalar, not a matrix + dA = fast_exp(A * dt) # scalar, not a matrix B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 661c88462..8057a8d32 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -8,6 +8,7 @@ from packaging import version +from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp from vllm.triton_utils import tl, triton TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") @@ -15,6 +16,76 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") @triton.autotune( configs=[ + # ================================================================= + # Higher warp count configs for better latency hiding + # More warps = more instructions in flight = better memory latency hiding + # 
================================================================= + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=2, + num_warps=8, + ), + # Smaller tiles with more stages for software pipelining + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}, + num_stages=2, + num_warps=4, + ), + # ================================================================= + # Low register pressure configs (num_stages=1) for large dstate + # ================================================================= + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=1, + num_warps=4, + ), + # num_stages=2 configs - moderate register pressure + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=2, + num_warps=4, + ), + # Original configs for larger dstate values triton.Config( {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, @@ -200,7 +271,7 @@ def _chunk_scan_fwd_kernel( offs_m[:, None] * 
stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate ) - scale_m = tl.exp(dA_cs_m) + scale_m = fast_exp(dA_cs_m) if BLOCK_SIZE_DSTATE <= 128: C = tl.load( C_ptrs, @@ -285,7 +356,7 @@ def _chunk_scan_fwd_kernel( ) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. - cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + cb *= fast_exp(dA_cs_m[:, None] - dA_cs_k[None, :]) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 11cc125bf..ed60593f5 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -8,6 +8,7 @@ import torch +from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp from vllm.triton_utils import tl, triton from .mamba_ssm import softplus @@ -116,6 +117,34 @@ def _chunk_cumsum_fwd_kernel( @triton.autotune( configs=[ + # Small headdim/dstate configs (hdim<=64, dstate<=128) - increased parallelism + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=3, + num_warps=4, + ), + # Low register pressure configs for large dstate (dstate=128) + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, + num_stages=2, + num_warps=4, + ), + # original configs for larger headdim/dstate values triton.Config( {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, @@ -251,7 +280,7 @@ def 
_chunk_state_fwd_kernel( dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( tl.float32 ) - scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + scale = fast_exp(dA_cs_last - dA_cs_k) * dt_k b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -273,238 +302,6 @@ def _chunk_state_fwd_kernel( tl.store(states_ptrs, states, mask=c_mask) -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=2, - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=2, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=2, - ), - ], - key=["hdim", "dstate", "chunk_size"], -) -@triton.jit -def _chunk_state_varlen_kernel( - # Pointers to matrices - x_ptr, - b_ptr, - dt_ptr, - dA_cumsum_ptr, - chunk_states_ptr, - cu_seqlens_ptr, - states_ptr, - initstates_ptr, - # Matrix dimensions - hdim: tl.constexpr, - dstate: tl.constexpr, - chunk_size: tl.constexpr, - nheads_ngroups_ratio: tl.constexpr, - # Strides - stride_x_seqlen: tl.int64, - stride_x_head: tl.int64, - stride_x_hdim: tl.constexpr, - stride_b_seqlen: tl.int64, - stride_b_head: tl.int64, - 
stride_b_dstate: tl.constexpr, - stride_dt_head: tl.int64, - stride_dt_chunk: tl.int64, - stride_dt_csize: tl.constexpr, - stride_dA_cs_head: tl.int64, - stride_dA_cs_chunk: tl.int64, - stride_dA_cs_csize: tl.constexpr, - stride_chunk_states_chunk: tl.int64, - stride_chunk_states_head: tl.int64, - stride_chunk_states_hdim: tl.int64, - stride_chunk_states_dstate: tl.constexpr, - stride_states_batch: tl.int64, - stride_states_head: tl.int64, - stride_states_hdim: tl.int64, - stride_states_dstate: tl.constexpr, - stride_init_states_batch: tl.int64, - stride_init_states_head: tl.int64, - stride_init_states_hdim: tl.int64, - stride_init_states_dstate: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - HAS_INITSTATES: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) - pid_c = (end_idx - 1) // chunk_size - b_ptr += ( - pid_c * chunk_size * stride_b_seqlen - + (pid_h // nheads_ngroups_ratio) * stride_b_head - ) - x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - chunk_states_ptr += ( - pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head - ) - - if HAS_INITSTATES: - # if there are init states provided, we differentiate between states (which - # are boundary conditions at a chunk boundary) and initstates (which are boundary - # conditions when a new example in a cont batch starts) - initstates_ptr += pid_h * stride_init_states_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + ( - 
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen - ) - b_ptrs = b_ptr + ( - offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen - ) - dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load( - dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize - ).to(tl.float32) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - - chunk_size_limit = end_idx - pid_c * chunk_size - start_idx = tl.load(cu_seqlens_ptr + pid_b) - start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load( - x_ptrs, - mask=(offs_m[:, None] < hdim) - & (offs_k[None, :] < chunk_size_limit - k) - & (offs_k[None, :] >= start_idx_cur - k), - other=0.0, - ) - b = tl.load( - b_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) - & (offs_n[None, :] < dstate) - & (offs_k[:, None] >= start_idx_cur - k), - other=0.0, - ).to(tl.float32) - dA_cs_k = tl.load( - dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 - ).to(tl.float32) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( - tl.float32 - ) - scale = tl.where( - (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), - tl.exp(dA_cs_last - dA_cs_k) * dt_k, - 0.0, - ) - b *= scale[:, None] - b = b.to(x_ptr.dtype.element_ty) - acc += tl.dot(x, b) - x_ptrs += BLOCK_SIZE_K * stride_x_seqlen - b_ptrs += BLOCK_SIZE_K * stride_b_seqlen - dt_ptrs += BLOCK_SIZE_K * stride_dt_csize - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - - # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk - # If HAS_INITSTATES==True need to consider two possibilities - # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs - # - if state_idx >= pid * chunk_size, then we need to insert initstates - if ( - (start_idx < pid_c * chunk_size) # 
first chunk - or (HAS_INITSTATES) - ): - dA_cs_boundary = 0.0 # default - - if not HAS_INITSTATES: - past_states_ptrs = chunk_states_ptr + ( - offs_m[:, None] * stride_chunk_states_hdim - + offs_n[None, :] * stride_chunk_states_dstate - ) - else: - # - this seems repetitive, buts its to help the compiler - if start_idx < pid_c * chunk_size: - past_states_ptrs = chunk_states_ptr + ( - offs_m[:, None] * stride_chunk_states_hdim - + offs_n[None, :] * stride_chunk_states_dstate - ) - else: - past_states_ptrs = initstates_ptr + ( - pid_b * stride_init_states_batch - + offs_m[:, None] * stride_init_states_hdim - + offs_n[None, :] * stride_init_states_dstate - ) - - # need to adjust the boundary - if start_idx > pid_c * chunk_size: - dA_cs_boundary = tl.load( - dA_cumsum_ptr - + (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize - ).to(tl.float32) - - past_states = tl.load( - past_states_ptrs, - mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), - other=0.0, - ).to(tl.float32) - - scale = tl.exp(dA_cs_last - dA_cs_boundary) - acc += past_states * scale - - states = acc.to(states_ptr.dtype.element_ty) - - states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + ( - offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate - ) - c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) - tl.store(states_ptrs, states, mask=c_mask) - - def _chunk_cumsum_fwd( dt, A, @@ -612,89 +409,3 @@ def _chunk_state_fwd( stride_dA_cs_csize=dA_cumsum.stride(2), ) return states - - -def chunk_state_varlen( - B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None -): - total_seqlen, nheads, headdim = x.shape - _, nchunks, chunk_size = dt.shape - _, ngroups, dstate = B.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - assert nheads % ngroups == 0 - 
assert B.shape == (total_seqlen, ngroups, dstate) - assert dt.shape == (nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert chunk_states.shape == (nchunks, nheads, headdim, dstate) - - if initial_states is not None: - assert initial_states.shape == (batch, nheads, headdim, dstate) - - states = torch.empty( - batch, - nheads, - headdim, - dstate, - dtype=chunk_states.dtype, - device=chunk_states.device, - ) - - initial_states_strides = ( - ( - initial_states.stride(0), - initial_states.stride(1), - initial_states.stride(2), - initial_states.stride(3), - ) - if initial_states is not None - else (0, 0, 0, 0) - ) - - grid = lambda META: ( - triton.cdiv(headdim, META["BLOCK_SIZE_M"]) - * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), - batch, - nheads, - ) - with torch.cuda.device(x.device.index): - _chunk_state_varlen_kernel[grid]( - x_ptr=x, - b_ptr=B, - dt_ptr=dt, - dA_cumsum_ptr=dA_cumsum, - chunk_states_ptr=chunk_states, - cu_seqlens_ptr=cu_seqlens, - states_ptr=states, - initstates_ptr=initial_states, - hdim=headdim, - dstate=dstate, - chunk_size=chunk_size, - nheads_ngroups_ratio=nheads // ngroups, - stride_x_seqlen=x.stride(0), - stride_x_head=x.stride(1), - stride_x_hdim=x.stride(2), - stride_b_seqlen=B.stride(0), - stride_b_head=B.stride(1), - stride_b_dstate=B.stride(2), - stride_dt_head=dt.stride(0), - stride_dt_chunk=dt.stride(1), - stride_dt_csize=dt.stride(2), - stride_dA_cs_head=dA_cumsum.stride(0), - stride_dA_cs_chunk=dA_cumsum.stride(1), - stride_dA_cs_csize=dA_cumsum.stride(2), - stride_chunk_states_chunk=chunk_states.stride(0), - stride_chunk_states_head=chunk_states.stride(1), - stride_chunk_states_hdim=chunk_states.stride(2), - stride_chunk_states_dstate=chunk_states.stride(3), - stride_states_batch=states.stride(0), - stride_states_head=states.stride(1), - stride_states_hdim=states.stride(2), - stride_states_dstate=states.stride(3), - stride_init_states_batch=initial_states_strides[0], - 
stride_init_states_head=initial_states_strides[1], - stride_init_states_hdim=initial_states_strides[2], - stride_init_states_dstate=initial_states_strides[3], - HAS_INITSTATES=initial_states is not None, - ) - return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index ac905ada7..4c93a768b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -107,18 +107,15 @@ def _mamba_chunk_scan_combined_fwd( # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - # - for handling chunked prefill, this requires i) initial_states and - # ii) seq_idx to be all specified. - # - When a new seq_idx is detected, we will stop passing the prev_state - # and switch accordingly to the init_state corresponding to the new seq_idx. + # - parallelized across sequences using last_chunk_indices to derive + # per-sequence chunk ranges. Each sequence's state passing runs independently. states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum, # (nheads, nchunks, chunk_size) - cu_chunk_seqlens, + last_chunk_indices, initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, # (batch, nheads, headdim*dstate) - seq_idx=seq_idx, out_dtype=state_dtype if state_dtype is not None else C.dtype, ) states = rearrange(states, "... (p n) -> ... 
p n", n=dstate) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 5481bab17..5c5cb9d37 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -8,6 +8,7 @@ import torch +from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp from vllm.triton_utils import tl, triton @@ -29,12 +30,9 @@ def _state_passing_fwd_kernel( out_ptr, dA_cs_ptr, initstates_ptr, - seq_idx_ptr, - cu_chunk_seqlens_ptr, + last_chunk_indices_ptr, # Matrix dimensions dim: tl.constexpr, - nchunks, - seqlen, chunk_size: tl.constexpr, # Strides stride_states_chunk: tl.int64, @@ -49,55 +47,51 @@ def _state_passing_fwd_kernel( stride_initstates_batch: tl.int64, stride_initstates_head: tl.int64, stride_initstates_dim: tl.constexpr, - stride_seq_idx_chunk: tl.constexpr, # Meta-parameters HAS_INITSTATES: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - pid_h = tl.program_id(axis=1) pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) - states_ptr += pid_h * stride_states_head - dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize - out_ptr += pid_h * stride_out_head + # Derive this sequence's chunk range from last_chunk_indices + chunk_end = tl.load(last_chunk_indices_ptr + pid_b) + 1 + chunk_start = ( + tl.load(last_chunk_indices_ptr + pid_b - 1, mask=pid_b > 0, other=-1) + 1 + ) + + # Offset pointers to this sequence's first chunk + states_ptr += chunk_start * stride_states_chunk + pid_h * stride_states_head + dA_cs_ptr += ( + pid_h * stride_dA_cs_head + + chunk_start * stride_dA_cs_chunk + + (chunk_size - 1) * stride_dA_cs_csize + ) + out_ptr += chunk_start * stride_out_chunk + pid_h * stride_out_head offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) states_ptrs = states_ptr + offs_m * stride_states_dim out_ptrs = out_ptr + offs_m * stride_out_dim + # 
Load initial state once — no per-chunk branching needed if HAS_INITSTATES: initstates_ptrs = ( initstates_ptr + + pid_b * stride_initstates_batch + pid_h * stride_initstates_head + offs_m * stride_initstates_dim ) - states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) else: states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - prev_seq_idx = 0 - for c in range(nchunks): + # Loop over only this sequence's chunks — branchless + nchunks_this_seq = chunk_end - chunk_start + for _ in range(nchunks_this_seq): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk) - # we have started a new sequence - if prev_seq_idx != seq_idx: - if HAS_INITSTATES: - initstates_ptrs = ( - initstates_ptr - + seq_idx * stride_initstates_batch - + pid_h * stride_initstates_head - + offs_m * stride_initstates_dim - ) - states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to( - tl.float32 - ) - else: - states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - - prev_seq_idx = seq_idx - states = tl.exp(dA_cs) * states + new_states + states = fast_exp(dA_cs) * states + new_states tl.store(out_ptrs, states, mask=offs_m < dim) states_ptrs += stride_states_chunk @@ -108,15 +102,14 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, dA_cumsum, - cu_chunk_seqlens, - seq_idx, + last_chunk_indices, initial_states=None, out_dtype=None, ): nchunks, nheads, dim = states.shape chunk_size = dA_cumsum.shape[-1] + batch = last_chunk_indices.shape[0] assert dA_cumsum.shape == (nheads, nchunks, chunk_size) - seqlen = seq_idx.shape[-1] out_dtype = states.dtype if out_dtype is None else out_dtype out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype) @@ -126,19 +119,16 @@ def _state_passing_fwd( else (0, 0, 0) ) - grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads) + grid = lambda META: (triton.cdiv(dim, 
META["BLOCK_SIZE"]), batch, nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( states_ptr=states, out_ptr=out, dA_cs_ptr=dA_cumsum, initstates_ptr=initial_states, - seq_idx_ptr=seq_idx, - cu_chunk_seqlens_ptr=cu_chunk_seqlens, + last_chunk_indices_ptr=last_chunk_indices, dim=dim, - nchunks=nchunks, - seqlen=seqlen if seq_idx is not None else 0, - chunk_size=chunk_size if seq_idx is not None else 0, + chunk_size=chunk_size, stride_states_chunk=states.stride(0), stride_states_head=states.stride(1), stride_states_dim=states.stride(2), @@ -151,7 +141,6 @@ def _state_passing_fwd( stride_initstates_batch=initial_states_strides[0], stride_initstates_head=initial_states_strides[1], stride_initstates_dim=initial_states_strides[2], - stride_seq_idx_chunk=seq_idx.stride(0), HAS_INITSTATES=initial_states is not None, ) return out diff --git a/vllm/model_executor/layers/mamba/ops/triton_helpers.py b/vllm/model_executor/layers/mamba/ops/triton_helpers.py new file mode 100644 index 000000000..186cb27bd --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/triton_helpers.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.triton_utils import tl, triton + + +@triton.jit +def fast_exp(x): + """Compute exp(x) as exp2(x * log2(e)); meant as a faster stand-in for tl.exp(). + + tl.math.exp2 is described as mapping to a single ex2.approx.f32 PTX instruction, + while tl.exp goes through libdevice __nv_expf (call overhead, extra range checks). + NOTE(review): backend/Triton-version specific and approximate - confirm tolerance. + """ + # exp(x) = exp2(x * log2(e)), where log2(e) = 1/ln(2) = 1.4426950408889634 + LOG2E = tl.constexpr(1.4426950408889634) + return tl.math.exp2(LOG2E * x)