From bbf81f9a9284d572b69db2c4fb002c2a8a80d507 Mon Sep 17 00:00:00 2001 From: Asaf Gardin <39553475+Josephasafg@users.noreply.github.com> Date: Sun, 1 Mar 2026 14:40:23 +0200 Subject: [PATCH] [Mamba1] - Kernel Level Chunk Alignment for Prefix Caching (#34798) Signed-off-by: Josephasafg --- csrc/mamba/mamba_ssm/selective_scan.h | 4 +- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 103 ++++++++++----- csrc/ops.h | 4 +- csrc/torch_bindings.cpp | 4 +- tests/kernels/mamba/test_mamba_ssm.py | 4 + vllm/_custom_ops.py | 4 + .../layers/mamba/mamba_mixer.py | 4 + .../layers/mamba/ops/mamba_ssm.py | 4 + vllm/v1/attention/backends/mamba1_attn.py | 33 ++++- vllm/v1/attention/backends/mamba2_attn.py | 112 +--------------- vllm/v1/attention/backends/mamba_attn.py | 121 ++++++++++++++++++ 11 files changed, 251 insertions(+), 146 deletions(-) diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index e93455a57..8f33c7cfa 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -17,7 +17,7 @@ struct SSMParamsBase { using index_t = size_t; - int batch, dim, seqlen, dstate, n_groups, n_chunks; + int batch, dim, seqlen, dstate, n_groups; int dim_ngroups_ratio; bool is_variable_B; bool is_variable_C; @@ -72,6 +72,8 @@ struct SSMParamsBase { void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use + void *__restrict__ cu_chunk_seqlen_ptr; // (nchunks+1,) - cumulative chunk token offsets + void *__restrict__ last_chunk_indices_ptr; // (batch,) - index of last chunk per sequence }; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index fb2a2e578..d852a0ed4 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -81,7 +81,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { constexpr bool kIsVariableC = Ktraits::kIsVariableC; constexpr bool kHasZ = Ktraits::kHasZ; constexpr bool kVarlen = Ktraits::kVarlen; - constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNItems = Ktraits::kNItems; constexpr int kNRows = Ktraits::kNRows; constexpr bool kDirectIO = Ktraits::kDirectIO; @@ -161,17 +160,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } - - // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; - // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; - // } - - constexpr int kChunkSize = kNThreads * kNItems; - // Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility - const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048; - const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size; + const int block_size = params.cache_enabled ? params.block_size : 2048; const int* batch_cache_indices = cache_indices != nullptr ? cache_indices + batch_id * params.cache_indices_stride : nullptr; @@ -181,10 +171,44 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { reinterpret_cast(params.block_idx_last_scheduled_token_ptr) : nullptr; const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ? reinterpret_cast(params.initial_state_idx_ptr) : nullptr; + const int* cu_chunk_seqlen = params.cu_chunk_seqlen_ptr != nullptr ? + reinterpret_cast(params.cu_chunk_seqlen_ptr) : nullptr; + const int* last_chunk_indices = params.last_chunk_indices_ptr != nullptr ? + reinterpret_cast(params.last_chunk_indices_ptr) : nullptr; const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index; + const int block_idx_first = (params.cache_enabled && block_idx_first_scheduled != nullptr) ? + block_idx_first_scheduled[batch_id] : 0; + + // Determine chunk boundaries from pre-computed metadata (APC mode) + // or fall back to simple block_size chunking. + int first_chunk_idx, n_chunks; + int current_position; + + if (cu_chunk_seqlen != nullptr && last_chunk_indices != nullptr) { + const int last_chunk_idx = last_chunk_indices[batch_id]; + first_chunk_idx = (batch_id == 0) ? 0 : last_chunk_indices[batch_id - 1] + 1; + n_chunks = last_chunk_idx - first_chunk_idx + 1; + // Derive current_position: if the first chunk is partial (fills remainder + // of a started block), offset into the block accordingly. + const int first_chunk_tokens = cu_chunk_seqlen[first_chunk_idx + 1] - cu_chunk_seqlen[first_chunk_idx]; + const int chunk_start_offset = (n_chunks > 1 && first_chunk_tokens < block_size) + ? (block_size - first_chunk_tokens) : 0; + current_position = block_idx_first * block_size + chunk_start_offset; + } else { + first_chunk_idx = 0; + n_chunks = (seqlen + block_size - 1) / block_size; + current_position = 0; + } + + int tokens_processed = 0; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + const int chunk_tokens = (cu_chunk_seqlen != nullptr) + ? cu_chunk_seqlen[first_chunk_idx + chunk + 1] - cu_chunk_seqlen[first_chunk_idx + chunk] + : min(block_size, seqlen - tokens_processed); + if (chunk_tokens <= 0) break; input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; __syncthreads(); @@ -193,12 +217,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (!kDirectIO) { if (r > 0) { __syncthreads(); } } - load_input(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize); + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, chunk_tokens); if constexpr (!kDirectIO) { __syncthreads(); } - load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, chunk_tokens); } - u += kChunkSize; - delta += kChunkSize; + u += chunk_tokens; + delta += chunk_tokens; float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; #pragma unroll @@ -232,7 +256,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { weight_t B_vals[kNItems], C_vals[kNItems]; if constexpr (kIsVariableB) { load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (seqlen - chunk * kChunkSize) * (1)); + smem_load_weight, chunk_tokens); if constexpr (!kIsVariableC) { #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -243,7 +267,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (kIsVariableC) { auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1)); + smem_load_weight_C, chunk_tokens); if constexpr (!kIsVariableB) { #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -266,10 +290,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { for (int i = 0; i < kNItems; ++i) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) { - thread_data[i] = make_float2(1.f, 0.f); - } + if (threadIdx.x * kNItems + i >= chunk_tokens) { + thread_data[i] = make_float2(1.f, 0.f); } } // Initialize running total @@ -301,14 +323,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (threadIdx.x == 0) { smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix; - // Store state at the end of each chunk when cache is enabled + // Store state at the end of each aligned chunk when cache is enabled if (params.cache_enabled && batch_cache_indices != nullptr) { - size_t cache_slot; if (chunk == n_chunks - 1) { cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]]; } else { - cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk]; + const int block_idx_completed = (current_position + chunk_tokens - 1) / block_size; + cache_slot = batch_cache_indices[block_idx_completed]; } size_t state_offset = cache_slot * params.ssm_states_batch_stride + @@ -331,38 +353,41 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride - + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + + dim_id * kNRows * params.out_d_stride + tokens_processed; __syncthreads(); #pragma unroll for (int r = 0; r < kNRows; ++r) { if constexpr (!kDirectIO) { if (r > 0) { __syncthreads(); } } - store_output(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, chunk_tokens); } if constexpr (kHasZ) { input_t *z = reinterpret_cast(params.z_ptr) + sequence_start_index * params.z_batch_stride - + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + + dim_id * kNRows * params.z_d_stride + tokens_processed; input_t *out_z = reinterpret_cast(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride - + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + + dim_id * kNRows * params.out_z_d_stride + tokens_processed; #pragma unroll for (int r = 0; r < kNRows; ++r) { input_t z_vals[kNItems]; __syncthreads(); - load_input(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize); + load_input(z + r * params.z_d_stride, z_vals, smem_load, chunk_tokens); #pragma unroll for (int i = 0; i < kNItems; ++i) { float z_val = z_vals[i]; out_vals[r][i] *= z_val / (1 + expf(-z_val)); } __syncthreads(); - store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, chunk_tokens); } } - Bvar += kChunkSize * 1; - Cvar += kChunkSize * 1; + Bvar += chunk_tokens; + Cvar += chunk_tokens; + + tokens_processed += chunk_tokens; + current_position += chunk_tokens; } } @@ -506,7 +531,9 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, int64_t block_size, const std::optional &block_idx_first_scheduled_token, const std::optional &block_idx_last_scheduled_token, - const std::optional &initial_state_idx) { + const std::optional &initial_state_idx, + const std::optional &cu_chunk_seqlen, + const std::optional &last_chunk_indices) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -548,6 +575,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr; params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr; params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr; + params.cu_chunk_seqlen_ptr = cu_chunk_seqlen.has_value() ? cu_chunk_seqlen.value().data_ptr() : nullptr; + params.last_chunk_indices_ptr = last_chunk_indices.has_value() ? last_chunk_indices.value().data_ptr() : nullptr; // All stride are in elements, not bytes. params.A_d_stride = A.stride(0); @@ -633,7 +662,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, int64_t block_size, const std::optional &block_idx_first_scheduled_token, const std::optional &block_idx_last_scheduled_token, - const std::optional &initial_state_idx) { + const std::optional &initial_state_idx, + const std::optional &cu_chunk_seqlen, + const std::optional &last_chunk_indices) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -778,7 +809,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, block_size, block_idx_first_scheduled_token, block_idx_last_scheduled_token, - initial_state_idx + initial_state_idx, + cu_chunk_seqlen, + last_chunk_indices ); diff --git a/csrc/ops.h b/csrc/ops.h index 690342b37..921d6484d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -371,7 +371,9 @@ void selective_scan_fwd( const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size, const std::optional& block_idx_first_scheduled_token, const std::optional& block_idx_last_scheduled_token, - const std::optional& initial_state_idx); + const std::optional& initial_state_idx, + const std::optional& cu_chunk_seqlen, + const std::optional& last_chunk_indices); torch::Tensor dynamic_4bit_int_moe_cpu( torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 8be30b209..9ba18289e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -640,7 +640,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int block_size," "Tensor? block_idx_first_scheduled_token," "Tensor? block_idx_last_scheduled_token," - "Tensor? initial_state_idx) -> ()"); + "Tensor? initial_state_idx," + "Tensor? cu_chunk_seqlen," + "Tensor? last_chunk_indices) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); // Hadamard transforms diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 905207109..9a00e1d04 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -183,6 +183,8 @@ def selective_scan_opcheck_fn( block_idx_first_scheduled_token=None, block_idx_last_scheduled_token=None, initial_state_idx=None, + cu_chunk_seqlen=None, + last_chunk_indices=None, ): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). @@ -231,6 +233,8 @@ def selective_scan_opcheck_fn( block_idx_first_scheduled_token, block_idx_last_scheduled_token, initial_state_idx, + cu_chunk_seqlen, + last_chunk_indices, ), test_utils=["test_schema", "test_faketensor"], ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 46f9dfad9..9ed8dfa8d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2021,6 +2021,8 @@ def selective_scan_fwd( block_idx_first_scheduled_token: torch.Tensor | None = None, block_idx_last_scheduled_token: torch.Tensor | None = None, initial_state_idx: torch.Tensor | None = None, + cu_chunk_seqlen: torch.Tensor | None = None, + last_chunk_indices: torch.Tensor | None = None, ): torch.ops._C.selective_scan_fwd( u, @@ -2041,6 +2043,8 @@ def selective_scan_fwd( block_idx_first_scheduled_token, block_idx_last_scheduled_token, initial_state_idx, + cu_chunk_seqlen, + last_chunk_indices, ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 24e189a5c..6a33fc7d6 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p + cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p + last_chunk_indices_p = attn_metadata.last_chunk_indices_p # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer): block_idx_first_scheduled_token=block_idx_first_scheduled_token_p, block_idx_last_scheduled_token=block_idx_last_scheduled_token_p, initial_state_idx=block_idx_last_computed_token_p, + cu_chunk_seqlen=cu_chunk_seqlen_p, + last_chunk_indices=last_chunk_indices_p, ) ssm_outputs.append(scan_out_p) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index a0df65f90..44e73dd20 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -497,6 +497,8 @@ def selective_scan_fn( block_idx_first_scheduled_token=None, block_idx_last_scheduled_token=None, initial_state_idx=None, + cu_chunk_seqlen=None, + last_chunk_indices=None, ) -> torch.Tensor: """ u: (dim, total_length) for varlen or (batch, dim, seqlen) @@ -588,6 +590,8 @@ def selective_scan_fn( block_idx_first_scheduled_token, block_idx_last_scheduled_token, initial_state_idx, + cu_chunk_seqlen, + last_chunk_indices, ) if z is None: diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index c7228ecea..890340620 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass +from dataclasses import dataclass, replace +from typing import Any -from vllm.v1.attention.backend import AttentionBackend +from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata from vllm.v1.attention.backends.mamba_attn import ( BaseMambaAttentionMetadata, BaseMambaAttentionMetadataBuilder, @@ -29,3 +30,31 @@ class Mamba1AttentionMetadataBuilder( BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata] ): metadata_cls = Mamba1AttentionMetadata + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + **kwargs: Any, + ) -> Mamba1AttentionMetadata: + common = self._compute_common_metadata(common_attn_metadata) + + if ( + common.num_prefills > 0 + and self.vllm_config.cache_config.mamba_cache_mode == "all" + ): + cu_chunk_seqlen_p, _, last_chunk_indices_p = ( + self._build_chunk_metadata_tensors( + self.kv_cache_spec.block_size, + common, + common_attn_metadata, + ) + ) + return replace( + common, + cu_chunk_seqlen_p=cu_chunk_seqlen_p, + last_chunk_indices_p=last_chunk_indices_p, + ) + + return common diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 94587c3d6..5e8abbab5 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -7,7 +7,6 @@ from typing import Any import torch from vllm.config import VllmConfig -from vllm.utils.math_utils import cdiv from vllm.v1.attention.backend import ( AttentionBackend, CommonAttentionMetadata, @@ -105,14 +104,6 @@ class Mamba2AttentionMetadata(BaseMambaAttentionMetadata): # Chunk-related metadata (only for prefill) seq_idx_p: torch.Tensor | None = None - # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for - # each chunk, its offsets into the varlen sequence dimension. It is defined - # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to - # cu_chunk_seqlen_p[i+1]. - cu_chunk_seqlen_p: torch.Tensor | None = None - # last_chunk_indices_p is a tensor of shape (batch,) that contains the - # index of the last chunk for every sequence in the (prefill) batch. - last_chunk_indices_p: torch.Tensor | None = None class Mamba2AttentionMetadataBuilder( @@ -134,68 +125,6 @@ class Mamba2AttentionMetadataBuilder( ) self.chunk_size: int = chunk_size - def _compute_chunk_metadata( - self, - num_prefills: int, - num_computed_tokens_p_cpu: torch.Tensor, - query_start_loc_p_cpu: torch.Tensor, - ) -> tuple[list[int], list[int], list[int]]: - """ - Compute chunk-specific metadata for Mamba2. - - The code below carefully constructs the chunks such that: - 1. Chunks contain tokens from a *single* sequence only. - 2. For every sequence, we are guaranteed that we can - retrieve the mamba state *every* chunk_size tokens. - Constraint (1) dramatically simplifies the mamba2 kernels. - Constraint (2) dramatically simplifies the implementation - of prefix caching for mamba2 (wip). We need to take care - of the interaction with chunked prefill in order to - satisfy constraint (2). - """ - # TODO (tdoublep): This code could probably be optimized. - cu_chunk_seqlen = [] - seq_idx = [] - last_chunk_indices = [] - seqlen_pos = 0 - - for req_idx in range(num_prefills): - this_num_computed = num_computed_tokens_p_cpu[req_idx].item() - this_new_tokens = ( - query_start_loc_p_cpu[req_idx + 1].item() - - query_start_loc_p_cpu[req_idx].item() - ) - - # if computed tokens are not chunk-aligned, use the first - # chunk to finish it off - if this_num_computed % self.chunk_size != 0: - seq_idx.append(req_idx) - cu_chunk_seqlen.append(seqlen_pos) - # how many tokens to finish the chunk? - chunk_len = ( - cdiv(this_num_computed, self.chunk_size) * self.chunk_size - - this_num_computed - ) - # we can only use at most this_new_tokens - chunk_len = min(chunk_len, this_new_tokens) - seqlen_pos += chunk_len - this_new_tokens -= chunk_len - - n_chunks = cdiv(this_new_tokens, self.chunk_size) - for chunk in range(n_chunks): - seq_idx.append(req_idx) - cu_chunk_seqlen.append(seqlen_pos) - chunk_len = min(self.chunk_size, this_new_tokens) - seqlen_pos += chunk_len - this_new_tokens -= chunk_len - - assert this_new_tokens == 0 - last_chunk_indices.append(len(cu_chunk_seqlen) - 1) - - cu_chunk_seqlen.append(seqlen_pos) - - return cu_chunk_seqlen, seq_idx, last_chunk_indices - def build( self, common_prefix_len: int, @@ -220,41 +149,12 @@ class Mamba2AttentionMetadataBuilder( else False ) - num_reqs = common.num_reqs - num_prefills = common.num_prefills - num_decode_tokens = common.num_decode_tokens - - num_computed_tokens_cpu = ( - common_attn_metadata.compute_num_computed_tokens().cpu() - ) - num_computed_tokens_p_cpu = num_computed_tokens_cpu[ - num_reqs - num_prefills : num_reqs - ] - query_start_loc_p_cpu = ( - common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :] - - num_decode_tokens - ) - - cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata( - num_prefills, - num_computed_tokens_p_cpu, - query_start_loc_p_cpu, - ) - - seq_idx_p = torch.as_tensor( - seq_idx, - device=common_attn_metadata.query_start_loc.device, - dtype=torch.int32, - ) - cu_chunk_seqlen_p = torch.as_tensor( - cu_chunk_seqlen, - device=common_attn_metadata.query_start_loc.device, - dtype=torch.int32, - ) - last_chunk_indices_p = torch.as_tensor( - last_chunk_indices, - device=common_attn_metadata.query_start_loc.device, - dtype=torch.int32, + cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = ( + self._build_chunk_metadata_tensors( + self.chunk_size, + common, + common_attn_metadata, + ) ) return replace( diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index c4ffb16f5..27c9b85eb 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata: # The following tensor is only used for prefix caching in align mode seq_lens: torch.Tensor + # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for + # each chunk, its offsets into the varlen sequence dimension. It is defined + # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to + # cu_chunk_seqlen_p[i+1]. + cu_chunk_seqlen_p: torch.Tensor | None = None + # last_chunk_indices_p is a tensor of shape (batch,) that contains the + # index of the last chunk for every sequence in the (prefill) batch. + last_chunk_indices_p: torch.Tensor | None = None + # The following attributes are for triton implementation of causal_conv1d nums_dict: dict | None = None batch_ptr: torch.Tensor | None = None @@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): common_attn_metadata, num_accepted_tokens=num_accepted_tokens ) + def _compute_chunk_metadata( + self, + chunk_size: int, + num_prefills: int, + num_computed_tokens_p_cpu: torch.Tensor, + query_start_loc_p_cpu: torch.Tensor, + ) -> tuple[list[int], list[int], list[int]]: + """ + Compute chunk-specific metadata for Mamba models. + + The code below carefully constructs the chunks such that: + 1. Chunks contain tokens from a *single* sequence only. + 2. For every sequence, we are guaranteed that we can + retrieve the mamba state *every* chunk_size tokens. + Constraint (1) dramatically simplifies the mamba kernels. + Constraint (2) dramatically simplifies the implementation + of prefix caching for mamba (wip). We need to take care + of the interaction with chunked prefill in order to + satisfy constraint (2). + """ + # TODO (tdoublep): This code could probably be optimized. + cu_chunk_seqlen = [] + seq_idx = [] + last_chunk_indices = [] + seqlen_pos = 0 + + for req_idx in range(num_prefills): + this_num_computed = num_computed_tokens_p_cpu[req_idx].item() + this_new_tokens = ( + query_start_loc_p_cpu[req_idx + 1].item() + - query_start_loc_p_cpu[req_idx].item() + ) + + # if computed tokens are not chunk-aligned, use the first + # chunk to finish it off + if this_num_computed % chunk_size != 0: + seq_idx.append(req_idx) + cu_chunk_seqlen.append(seqlen_pos) + # how many tokens to finish the chunk? + chunk_len = ( + cdiv(this_num_computed, chunk_size) * chunk_size - this_num_computed + ) + # we can only use at most this_new_tokens + chunk_len = min(chunk_len, this_new_tokens) + seqlen_pos += chunk_len + this_new_tokens -= chunk_len + + n_chunks = cdiv(this_new_tokens, chunk_size) + for chunk in range(n_chunks): + seq_idx.append(req_idx) + cu_chunk_seqlen.append(seqlen_pos) + chunk_len = min(chunk_size, this_new_tokens) + seqlen_pos += chunk_len + this_new_tokens -= chunk_len + + assert this_new_tokens == 0 + last_chunk_indices.append(len(cu_chunk_seqlen) - 1) + + cu_chunk_seqlen.append(seqlen_pos) + + return cu_chunk_seqlen, seq_idx, last_chunk_indices + + def _build_chunk_metadata_tensors( + self, + chunk_size: int, + common: M, + common_attn_metadata: CommonAttentionMetadata, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute chunk metadata and return as device tensors. + Returns (cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p). + """ + num_reqs = common.num_reqs + num_prefills = common.num_prefills + num_decode_tokens = common.num_decode_tokens + + num_computed_tokens_cpu = ( + common_attn_metadata.compute_num_computed_tokens().cpu() + ) + num_computed_tokens_p_cpu = num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + query_start_loc_p_cpu = ( + common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :] + - num_decode_tokens + ) + + cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata( + chunk_size, + num_prefills, + num_computed_tokens_p_cpu, + query_start_loc_p_cpu, + ) + + device = common_attn_metadata.query_start_loc.device + cu_chunk_seqlen_p = torch.as_tensor( + cu_chunk_seqlen, + device=device, + dtype=torch.int32, + ) + seq_idx_p = torch.as_tensor( + seq_idx, + device=device, + dtype=torch.int32, + ) + last_chunk_indices_p = torch.as_tensor( + last_chunk_indices, + device=device, + dtype=torch.int32, + ) + return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p + def _compute_prefix_caching_block_indices( self, common_attn_metadata: CommonAttentionMetadata,