[Mamba1] - Kernel Level Chunk Alignment for Prefix Caching (#34798)
Signed-off-by: Josephasafg <ajgard7@gmail.com>
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
struct SSMParamsBase {
|
||||
using index_t = size_t;
|
||||
|
||||
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
||||
int batch, dim, seqlen, dstate, n_groups;
|
||||
int dim_ngroups_ratio;
|
||||
bool is_variable_B;
|
||||
bool is_variable_C;
|
||||
@@ -72,6 +72,8 @@ struct SSMParamsBase {
|
||||
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
|
||||
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
|
||||
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
|
||||
void *__restrict__ cu_chunk_seqlen_ptr; // (nchunks+1,) - cumulative chunk token offsets
|
||||
void *__restrict__ last_chunk_indices_ptr; // (batch,) - index of last chunk per sequence
|
||||
};
|
||||
|
||||
|
||||
|
||||
@@ -81,7 +81,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
||||
constexpr bool kHasZ = Ktraits::kHasZ;
|
||||
constexpr bool kVarlen = Ktraits::kVarlen;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
constexpr int kNRows = Ktraits::kNRows;
|
||||
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
||||
@@ -161,17 +160,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
||||
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
||||
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
||||
// }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNItems;
|
||||
|
||||
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
|
||||
const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
|
||||
const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
|
||||
const int block_size = params.cache_enabled ? params.block_size : 2048;
|
||||
|
||||
const int* batch_cache_indices = cache_indices != nullptr ?
|
||||
cache_indices + batch_id * params.cache_indices_stride : nullptr;
|
||||
@@ -181,10 +171,44 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
|
||||
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
|
||||
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
|
||||
const int* cu_chunk_seqlen = params.cu_chunk_seqlen_ptr != nullptr ?
|
||||
reinterpret_cast<const int*>(params.cu_chunk_seqlen_ptr) : nullptr;
|
||||
const int* last_chunk_indices = params.last_chunk_indices_ptr != nullptr ?
|
||||
reinterpret_cast<const int*>(params.last_chunk_indices_ptr) : nullptr;
|
||||
|
||||
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
|
||||
|
||||
const int block_idx_first = (params.cache_enabled && block_idx_first_scheduled != nullptr) ?
|
||||
block_idx_first_scheduled[batch_id] : 0;
|
||||
|
||||
// Determine chunk boundaries from pre-computed metadata (APC mode)
|
||||
// or fall back to simple block_size chunking.
|
||||
int first_chunk_idx, n_chunks;
|
||||
int current_position;
|
||||
|
||||
if (cu_chunk_seqlen != nullptr && last_chunk_indices != nullptr) {
|
||||
const int last_chunk_idx = last_chunk_indices[batch_id];
|
||||
first_chunk_idx = (batch_id == 0) ? 0 : last_chunk_indices[batch_id - 1] + 1;
|
||||
n_chunks = last_chunk_idx - first_chunk_idx + 1;
|
||||
// Derive current_position: if the first chunk is partial (fills remainder
|
||||
// of a started block), offset into the block accordingly.
|
||||
const int first_chunk_tokens = cu_chunk_seqlen[first_chunk_idx + 1] - cu_chunk_seqlen[first_chunk_idx];
|
||||
const int chunk_start_offset = (n_chunks > 1 && first_chunk_tokens < block_size)
|
||||
? (block_size - first_chunk_tokens) : 0;
|
||||
current_position = block_idx_first * block_size + chunk_start_offset;
|
||||
} else {
|
||||
first_chunk_idx = 0;
|
||||
n_chunks = (seqlen + block_size - 1) / block_size;
|
||||
current_position = 0;
|
||||
}
|
||||
|
||||
int tokens_processed = 0;
|
||||
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
const int chunk_tokens = (cu_chunk_seqlen != nullptr)
|
||||
? cu_chunk_seqlen[first_chunk_idx + chunk + 1] - cu_chunk_seqlen[first_chunk_idx + chunk]
|
||||
: min(block_size, seqlen - tokens_processed);
|
||||
if (chunk_tokens <= 0) break;
|
||||
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
||||
|
||||
__syncthreads();
|
||||
@@ -193,12 +217,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
if constexpr (!kDirectIO) {
|
||||
if (r > 0) { __syncthreads(); }
|
||||
}
|
||||
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize);
|
||||
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, chunk_tokens);
|
||||
if constexpr (!kDirectIO) { __syncthreads(); }
|
||||
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize);
|
||||
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, chunk_tokens);
|
||||
}
|
||||
u += kChunkSize;
|
||||
delta += kChunkSize;
|
||||
u += chunk_tokens;
|
||||
delta += chunk_tokens;
|
||||
|
||||
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
||||
#pragma unroll
|
||||
@@ -232,7 +256,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
weight_t B_vals[kNItems], C_vals[kNItems];
|
||||
if constexpr (kIsVariableB) {
|
||||
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
||||
smem_load_weight, (seqlen - chunk * kChunkSize) * (1));
|
||||
smem_load_weight, chunk_tokens);
|
||||
if constexpr (!kIsVariableC) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
@@ -243,7 +267,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
if constexpr (kIsVariableC) {
|
||||
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
||||
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
||||
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
|
||||
smem_load_weight_C, chunk_tokens);
|
||||
if constexpr (!kIsVariableB) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
@@ -266,10 +290,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
||||
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
||||
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
|
||||
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
|
||||
thread_data[i] = make_float2(1.f, 0.f);
|
||||
}
|
||||
if (threadIdx.x * kNItems + i >= chunk_tokens) {
|
||||
thread_data[i] = make_float2(1.f, 0.f);
|
||||
}
|
||||
}
|
||||
// Initialize running total
|
||||
@@ -301,14 +323,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
if (threadIdx.x == 0) {
|
||||
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
|
||||
|
||||
// Store state at the end of each chunk when cache is enabled
|
||||
// Store state at the end of each aligned chunk when cache is enabled
|
||||
if (params.cache_enabled && batch_cache_indices != nullptr) {
|
||||
|
||||
size_t cache_slot;
|
||||
if (chunk == n_chunks - 1) {
|
||||
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
|
||||
} else {
|
||||
cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
|
||||
const int block_idx_completed = (current_position + chunk_tokens - 1) / block_size;
|
||||
cache_slot = batch_cache_indices[block_idx_completed];
|
||||
}
|
||||
|
||||
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
|
||||
@@ -331,38 +353,41 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
}
|
||||
}
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
||||
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
||||
+ dim_id * kNRows * params.out_d_stride + tokens_processed;
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if constexpr (!kDirectIO) {
|
||||
if (r > 0) { __syncthreads(); }
|
||||
}
|
||||
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
|
||||
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, chunk_tokens);
|
||||
}
|
||||
|
||||
if constexpr (kHasZ) {
|
||||
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
|
||||
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
||||
+ dim_id * kNRows * params.z_d_stride + tokens_processed;
|
||||
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
|
||||
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
||||
+ dim_id * kNRows * params.out_z_d_stride + tokens_processed;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
input_t z_vals[kNItems];
|
||||
__syncthreads();
|
||||
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize);
|
||||
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, chunk_tokens);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float z_val = z_vals[i];
|
||||
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
||||
}
|
||||
__syncthreads();
|
||||
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
|
||||
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, chunk_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
Bvar += kChunkSize * 1;
|
||||
Cvar += kChunkSize * 1;
|
||||
Bvar += chunk_tokens;
|
||||
Cvar += chunk_tokens;
|
||||
|
||||
tokens_processed += chunk_tokens;
|
||||
current_position += chunk_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -506,7 +531,9 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
int64_t block_size,
|
||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||
const std::optional<torch::Tensor> &initial_state_idx) {
|
||||
const std::optional<torch::Tensor> &initial_state_idx,
|
||||
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
||||
const std::optional<torch::Tensor> &last_chunk_indices) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
@@ -548,6 +575,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
|
||||
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
|
||||
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
|
||||
params.cu_chunk_seqlen_ptr = cu_chunk_seqlen.has_value() ? cu_chunk_seqlen.value().data_ptr() : nullptr;
|
||||
params.last_chunk_indices_ptr = last_chunk_indices.has_value() ? last_chunk_indices.value().data_ptr() : nullptr;
|
||||
|
||||
// All stride are in elements, not bytes.
|
||||
params.A_d_stride = A.stride(0);
|
||||
@@ -633,7 +662,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
int64_t block_size,
|
||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||
const std::optional<torch::Tensor> &initial_state_idx) {
|
||||
const std::optional<torch::Tensor> &initial_state_idx,
|
||||
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
||||
const std::optional<torch::Tensor> &last_chunk_indices) {
|
||||
auto input_type = u.scalar_type();
|
||||
auto weight_type = A.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
@@ -778,7 +809,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
block_size,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
initial_state_idx
|
||||
initial_state_idx,
|
||||
cu_chunk_seqlen,
|
||||
last_chunk_indices
|
||||
);
|
||||
|
||||
|
||||
|
||||
@@ -371,7 +371,9 @@ void selective_scan_fwd(
|
||||
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
|
||||
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
||||
const std::optional<torch::Tensor>& initial_state_idx);
|
||||
const std::optional<torch::Tensor>& initial_state_idx,
|
||||
const std::optional<torch::Tensor>& cu_chunk_seqlen,
|
||||
const std::optional<torch::Tensor>& last_chunk_indices);
|
||||
|
||||
torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
||||
|
||||
@@ -640,7 +640,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"int block_size,"
|
||||
"Tensor? block_idx_first_scheduled_token,"
|
||||
"Tensor? block_idx_last_scheduled_token,"
|
||||
"Tensor? initial_state_idx) -> ()");
|
||||
"Tensor? initial_state_idx,"
|
||||
"Tensor? cu_chunk_seqlen,"
|
||||
"Tensor? last_chunk_indices) -> ()");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
// Hadamard transforms
|
||||
|
||||
@@ -183,6 +183,8 @@ def selective_scan_opcheck_fn(
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
initial_state_idx=None,
|
||||
cu_chunk_seqlen=None,
|
||||
last_chunk_indices=None,
|
||||
):
|
||||
"""if return_last_state is True, returns (out, last_state)
|
||||
last_state has shape (batch, dim, dstate).
|
||||
@@ -231,6 +233,8 @@ def selective_scan_opcheck_fn(
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
initial_state_idx,
|
||||
cu_chunk_seqlen,
|
||||
last_chunk_indices,
|
||||
),
|
||||
test_utils=["test_schema", "test_faketensor"],
|
||||
)
|
||||
|
||||
@@ -2021,6 +2021,8 @@ def selective_scan_fwd(
|
||||
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||
initial_state_idx: torch.Tensor | None = None,
|
||||
cu_chunk_seqlen: torch.Tensor | None = None,
|
||||
last_chunk_indices: torch.Tensor | None = None,
|
||||
):
|
||||
torch.ops._C.selective_scan_fwd(
|
||||
u,
|
||||
@@ -2041,6 +2043,8 @@ def selective_scan_fwd(
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
initial_state_idx,
|
||||
cu_chunk_seqlen,
|
||||
last_chunk_indices,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer):
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
|
||||
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
@@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer):
|
||||
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
|
||||
initial_state_idx=block_idx_last_computed_token_p,
|
||||
cu_chunk_seqlen=cu_chunk_seqlen_p,
|
||||
last_chunk_indices=last_chunk_indices_p,
|
||||
)
|
||||
ssm_outputs.append(scan_out_p)
|
||||
|
||||
|
||||
@@ -497,6 +497,8 @@ def selective_scan_fn(
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
initial_state_idx=None,
|
||||
cu_chunk_seqlen=None,
|
||||
last_chunk_indices=None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
@@ -588,6 +590,8 @@ def selective_scan_fn(
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
initial_state_idx,
|
||||
cu_chunk_seqlen,
|
||||
last_chunk_indices,
|
||||
)
|
||||
|
||||
if z is None:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any
|
||||
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
@@ -29,3 +30,31 @@ class Mamba1AttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
|
||||
):
|
||||
metadata_cls = Mamba1AttentionMetadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Mamba1AttentionMetadata:
|
||||
common = self._compute_common_metadata(common_attn_metadata)
|
||||
|
||||
if (
|
||||
common.num_prefills > 0
|
||||
and self.vllm_config.cache_config.mamba_cache_mode == "all"
|
||||
):
|
||||
cu_chunk_seqlen_p, _, last_chunk_indices_p = (
|
||||
self._build_chunk_metadata_tensors(
|
||||
self.kv_cache_spec.block_size,
|
||||
common,
|
||||
common_attn_metadata,
|
||||
)
|
||||
)
|
||||
return replace(
|
||||
common,
|
||||
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
|
||||
last_chunk_indices_p=last_chunk_indices_p,
|
||||
)
|
||||
|
||||
return common
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
CommonAttentionMetadata,
|
||||
@@ -105,14 +104,6 @@ class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
|
||||
|
||||
# Chunk-related metadata (only for prefill)
|
||||
seq_idx_p: torch.Tensor | None = None
|
||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||
# each chunk, its offsets into the varlen sequence dimension. It is defined
|
||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||
# cu_chunk_seqlen_p[i+1].
|
||||
cu_chunk_seqlen_p: torch.Tensor | None = None
|
||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||
# index of the last chunk for every sequence in the (prefill) batch.
|
||||
last_chunk_indices_p: torch.Tensor | None = None
|
||||
|
||||
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
@@ -134,68 +125,6 @@ class Mamba2AttentionMetadataBuilder(
|
||||
)
|
||||
self.chunk_size: int = chunk_size
|
||||
|
||||
def _compute_chunk_metadata(
|
||||
self,
|
||||
num_prefills: int,
|
||||
num_computed_tokens_p_cpu: torch.Tensor,
|
||||
query_start_loc_p_cpu: torch.Tensor,
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
"""
|
||||
Compute chunk-specific metadata for Mamba2.
|
||||
|
||||
The code below carefully constructs the chunks such that:
|
||||
1. Chunks contain tokens from a *single* sequence only.
|
||||
2. For every sequence, we are guaranteed that we can
|
||||
retrieve the mamba state *every* chunk_size tokens.
|
||||
Constraint (1) dramatically simplifies the mamba2 kernels.
|
||||
Constraint (2) dramatically simplifies the implementation
|
||||
of prefix caching for mamba2 (wip). We need to take care
|
||||
of the interaction with chunked prefill in order to
|
||||
satisfy constraint (2).
|
||||
"""
|
||||
# TODO (tdoublep): This code could probably be optimized.
|
||||
cu_chunk_seqlen = []
|
||||
seq_idx = []
|
||||
last_chunk_indices = []
|
||||
seqlen_pos = 0
|
||||
|
||||
for req_idx in range(num_prefills):
|
||||
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
||||
this_new_tokens = (
|
||||
query_start_loc_p_cpu[req_idx + 1].item()
|
||||
- query_start_loc_p_cpu[req_idx].item()
|
||||
)
|
||||
|
||||
# if computed tokens are not chunk-aligned, use the first
|
||||
# chunk to finish it off
|
||||
if this_num_computed % self.chunk_size != 0:
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
# how many tokens to finish the chunk?
|
||||
chunk_len = (
|
||||
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
|
||||
- this_num_computed
|
||||
)
|
||||
# we can only use at most this_new_tokens
|
||||
chunk_len = min(chunk_len, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
||||
for chunk in range(n_chunks):
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
chunk_len = min(self.chunk_size, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
assert this_new_tokens == 0
|
||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
|
||||
return cu_chunk_seqlen, seq_idx, last_chunk_indices
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
@@ -220,41 +149,12 @@ class Mamba2AttentionMetadataBuilder(
|
||||
else False
|
||||
)
|
||||
|
||||
num_reqs = common.num_reqs
|
||||
num_prefills = common.num_prefills
|
||||
num_decode_tokens = common.num_decode_tokens
|
||||
|
||||
num_computed_tokens_cpu = (
|
||||
common_attn_metadata.compute_num_computed_tokens().cpu()
|
||||
)
|
||||
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
query_start_loc_p_cpu = (
|
||||
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
|
||||
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
|
||||
num_prefills,
|
||||
num_computed_tokens_p_cpu,
|
||||
query_start_loc_p_cpu,
|
||||
)
|
||||
|
||||
seq_idx_p = torch.as_tensor(
|
||||
seq_idx,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_chunk_seqlen_p = torch.as_tensor(
|
||||
cu_chunk_seqlen,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
last_chunk_indices_p = torch.as_tensor(
|
||||
last_chunk_indices,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = (
|
||||
self._build_chunk_metadata_tensors(
|
||||
self.chunk_size,
|
||||
common,
|
||||
common_attn_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return replace(
|
||||
|
||||
@@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata:
|
||||
# The following tensor is only used for prefix caching in align mode
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||
# each chunk, its offsets into the varlen sequence dimension. It is defined
|
||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||
# cu_chunk_seqlen_p[i+1].
|
||||
cu_chunk_seqlen_p: torch.Tensor | None = None
|
||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||
# index of the last chunk for every sequence in the (prefill) batch.
|
||||
last_chunk_indices_p: torch.Tensor | None = None
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
@@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
|
||||
)
|
||||
|
||||
def _compute_chunk_metadata(
|
||||
self,
|
||||
chunk_size: int,
|
||||
num_prefills: int,
|
||||
num_computed_tokens_p_cpu: torch.Tensor,
|
||||
query_start_loc_p_cpu: torch.Tensor,
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
"""
|
||||
Compute chunk-specific metadata for Mamba models.
|
||||
|
||||
The code below carefully constructs the chunks such that:
|
||||
1. Chunks contain tokens from a *single* sequence only.
|
||||
2. For every sequence, we are guaranteed that we can
|
||||
retrieve the mamba state *every* chunk_size tokens.
|
||||
Constraint (1) dramatically simplifies the mamba kernels.
|
||||
Constraint (2) dramatically simplifies the implementation
|
||||
of prefix caching for mamba (wip). We need to take care
|
||||
of the interaction with chunked prefill in order to
|
||||
satisfy constraint (2).
|
||||
"""
|
||||
# TODO (tdoublep): This code could probably be optimized.
|
||||
cu_chunk_seqlen = []
|
||||
seq_idx = []
|
||||
last_chunk_indices = []
|
||||
seqlen_pos = 0
|
||||
|
||||
for req_idx in range(num_prefills):
|
||||
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
||||
this_new_tokens = (
|
||||
query_start_loc_p_cpu[req_idx + 1].item()
|
||||
- query_start_loc_p_cpu[req_idx].item()
|
||||
)
|
||||
|
||||
# if computed tokens are not chunk-aligned, use the first
|
||||
# chunk to finish it off
|
||||
if this_num_computed % chunk_size != 0:
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
# how many tokens to finish the chunk?
|
||||
chunk_len = (
|
||||
cdiv(this_num_computed, chunk_size) * chunk_size - this_num_computed
|
||||
)
|
||||
# we can only use at most this_new_tokens
|
||||
chunk_len = min(chunk_len, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
n_chunks = cdiv(this_new_tokens, chunk_size)
|
||||
for chunk in range(n_chunks):
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
chunk_len = min(chunk_size, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
assert this_new_tokens == 0
|
||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
|
||||
return cu_chunk_seqlen, seq_idx, last_chunk_indices
|
||||
|
||||
def _build_chunk_metadata_tensors(
|
||||
self,
|
||||
chunk_size: int,
|
||||
common: M,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute chunk metadata and return as device tensors.
|
||||
Returns (cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p).
|
||||
"""
|
||||
num_reqs = common.num_reqs
|
||||
num_prefills = common.num_prefills
|
||||
num_decode_tokens = common.num_decode_tokens
|
||||
|
||||
num_computed_tokens_cpu = (
|
||||
common_attn_metadata.compute_num_computed_tokens().cpu()
|
||||
)
|
||||
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
query_start_loc_p_cpu = (
|
||||
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
|
||||
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
|
||||
chunk_size,
|
||||
num_prefills,
|
||||
num_computed_tokens_p_cpu,
|
||||
query_start_loc_p_cpu,
|
||||
)
|
||||
|
||||
device = common_attn_metadata.query_start_loc.device
|
||||
cu_chunk_seqlen_p = torch.as_tensor(
|
||||
cu_chunk_seqlen,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
seq_idx_p = torch.as_tensor(
|
||||
seq_idx,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
last_chunk_indices_p = torch.as_tensor(
|
||||
last_chunk_indices,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p
|
||||
|
||||
def _compute_prefix_caching_block_indices(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
|
||||
Reference in New Issue
Block a user