[Mamba1] - Kernel Level Chunk Alignment for Prefix Caching (#34798)
Signed-off-by: Josephasafg <ajgard7@gmail.com>
This commit is contained in:
@@ -17,7 +17,7 @@
|
|||||||
struct SSMParamsBase {
|
struct SSMParamsBase {
|
||||||
using index_t = size_t;
|
using index_t = size_t;
|
||||||
|
|
||||||
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
int batch, dim, seqlen, dstate, n_groups;
|
||||||
int dim_ngroups_ratio;
|
int dim_ngroups_ratio;
|
||||||
bool is_variable_B;
|
bool is_variable_B;
|
||||||
bool is_variable_C;
|
bool is_variable_C;
|
||||||
@@ -72,6 +72,8 @@ struct SSMParamsBase {
|
|||||||
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
|
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
|
||||||
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
|
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
|
||||||
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
|
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
|
||||||
|
void *__restrict__ cu_chunk_seqlen_ptr; // (nchunks+1,) - cumulative chunk token offsets
|
||||||
|
void *__restrict__ last_chunk_indices_ptr; // (batch,) - index of last chunk per sequence
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -81,7 +81,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
||||||
constexpr bool kHasZ = Ktraits::kHasZ;
|
constexpr bool kHasZ = Ktraits::kHasZ;
|
||||||
constexpr bool kVarlen = Ktraits::kVarlen;
|
constexpr bool kVarlen = Ktraits::kVarlen;
|
||||||
constexpr int kNThreads = Ktraits::kNThreads;
|
|
||||||
constexpr int kNItems = Ktraits::kNItems;
|
constexpr int kNItems = Ktraits::kNItems;
|
||||||
constexpr int kNRows = Ktraits::kNRows;
|
constexpr int kNRows = Ktraits::kNRows;
|
||||||
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
||||||
@@ -161,17 +160,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
|
||||||
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
|
||||||
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
|
||||||
// }
|
|
||||||
|
|
||||||
constexpr int kChunkSize = kNThreads * kNItems;
|
|
||||||
|
|
||||||
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
|
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
|
||||||
const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
|
const int block_size = params.cache_enabled ? params.block_size : 2048;
|
||||||
const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
|
|
||||||
|
|
||||||
const int* batch_cache_indices = cache_indices != nullptr ?
|
const int* batch_cache_indices = cache_indices != nullptr ?
|
||||||
cache_indices + batch_id * params.cache_indices_stride : nullptr;
|
cache_indices + batch_id * params.cache_indices_stride : nullptr;
|
||||||
@@ -181,10 +171,44 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
|
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
|
||||||
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
|
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
|
||||||
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
|
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
|
||||||
|
const int* cu_chunk_seqlen = params.cu_chunk_seqlen_ptr != nullptr ?
|
||||||
|
reinterpret_cast<const int*>(params.cu_chunk_seqlen_ptr) : nullptr;
|
||||||
|
const int* last_chunk_indices = params.last_chunk_indices_ptr != nullptr ?
|
||||||
|
reinterpret_cast<const int*>(params.last_chunk_indices_ptr) : nullptr;
|
||||||
|
|
||||||
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
|
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
|
||||||
|
|
||||||
|
const int block_idx_first = (params.cache_enabled && block_idx_first_scheduled != nullptr) ?
|
||||||
|
block_idx_first_scheduled[batch_id] : 0;
|
||||||
|
|
||||||
|
// Determine chunk boundaries from pre-computed metadata (APC mode)
|
||||||
|
// or fall back to simple block_size chunking.
|
||||||
|
int first_chunk_idx, n_chunks;
|
||||||
|
int current_position;
|
||||||
|
|
||||||
|
if (cu_chunk_seqlen != nullptr && last_chunk_indices != nullptr) {
|
||||||
|
const int last_chunk_idx = last_chunk_indices[batch_id];
|
||||||
|
first_chunk_idx = (batch_id == 0) ? 0 : last_chunk_indices[batch_id - 1] + 1;
|
||||||
|
n_chunks = last_chunk_idx - first_chunk_idx + 1;
|
||||||
|
// Derive current_position: if the first chunk is partial (fills remainder
|
||||||
|
// of a started block), offset into the block accordingly.
|
||||||
|
const int first_chunk_tokens = cu_chunk_seqlen[first_chunk_idx + 1] - cu_chunk_seqlen[first_chunk_idx];
|
||||||
|
const int chunk_start_offset = (n_chunks > 1 && first_chunk_tokens < block_size)
|
||||||
|
? (block_size - first_chunk_tokens) : 0;
|
||||||
|
current_position = block_idx_first * block_size + chunk_start_offset;
|
||||||
|
} else {
|
||||||
|
first_chunk_idx = 0;
|
||||||
|
n_chunks = (seqlen + block_size - 1) / block_size;
|
||||||
|
current_position = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int tokens_processed = 0;
|
||||||
|
|
||||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||||
|
const int chunk_tokens = (cu_chunk_seqlen != nullptr)
|
||||||
|
? cu_chunk_seqlen[first_chunk_idx + chunk + 1] - cu_chunk_seqlen[first_chunk_idx + chunk]
|
||||||
|
: min(block_size, seqlen - tokens_processed);
|
||||||
|
if (chunk_tokens <= 0) break;
|
||||||
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@@ -193,12 +217,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
if constexpr (!kDirectIO) {
|
if constexpr (!kDirectIO) {
|
||||||
if (r > 0) { __syncthreads(); }
|
if (r > 0) { __syncthreads(); }
|
||||||
}
|
}
|
||||||
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize);
|
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, chunk_tokens);
|
||||||
if constexpr (!kDirectIO) { __syncthreads(); }
|
if constexpr (!kDirectIO) { __syncthreads(); }
|
||||||
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize);
|
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, chunk_tokens);
|
||||||
}
|
}
|
||||||
u += kChunkSize;
|
u += chunk_tokens;
|
||||||
delta += kChunkSize;
|
delta += chunk_tokens;
|
||||||
|
|
||||||
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -232,7 +256,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
weight_t B_vals[kNItems], C_vals[kNItems];
|
weight_t B_vals[kNItems], C_vals[kNItems];
|
||||||
if constexpr (kIsVariableB) {
|
if constexpr (kIsVariableB) {
|
||||||
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
||||||
smem_load_weight, (seqlen - chunk * kChunkSize) * (1));
|
smem_load_weight, chunk_tokens);
|
||||||
if constexpr (!kIsVariableC) {
|
if constexpr (!kIsVariableC) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int r = 0; r < kNRows; ++r) {
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
@@ -243,7 +267,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
if constexpr (kIsVariableC) {
|
if constexpr (kIsVariableC) {
|
||||||
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
||||||
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
||||||
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
|
smem_load_weight_C, chunk_tokens);
|
||||||
if constexpr (!kIsVariableB) {
|
if constexpr (!kIsVariableB) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int r = 0; r < kNRows; ++r) {
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
@@ -266,10 +290,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
for (int i = 0; i < kNItems; ++i) {
|
for (int i = 0; i < kNItems; ++i) {
|
||||||
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
||||||
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
||||||
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
|
if (threadIdx.x * kNItems + i >= chunk_tokens) {
|
||||||
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
|
thread_data[i] = make_float2(1.f, 0.f);
|
||||||
thread_data[i] = make_float2(1.f, 0.f);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Initialize running total
|
// Initialize running total
|
||||||
@@ -301,14 +323,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
|
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
|
||||||
|
|
||||||
// Store state at the end of each chunk when cache is enabled
|
// Store state at the end of each aligned chunk when cache is enabled
|
||||||
if (params.cache_enabled && batch_cache_indices != nullptr) {
|
if (params.cache_enabled && batch_cache_indices != nullptr) {
|
||||||
|
|
||||||
size_t cache_slot;
|
size_t cache_slot;
|
||||||
if (chunk == n_chunks - 1) {
|
if (chunk == n_chunks - 1) {
|
||||||
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
|
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
|
||||||
} else {
|
} else {
|
||||||
cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
|
const int block_idx_completed = (current_position + chunk_tokens - 1) / block_size;
|
||||||
|
cache_slot = batch_cache_indices[block_idx_completed];
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
|
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
|
||||||
@@ -331,38 +353,41 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
||||||
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
+ dim_id * kNRows * params.out_d_stride + tokens_processed;
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int r = 0; r < kNRows; ++r) {
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
if constexpr (!kDirectIO) {
|
if constexpr (!kDirectIO) {
|
||||||
if (r > 0) { __syncthreads(); }
|
if (r > 0) { __syncthreads(); }
|
||||||
}
|
}
|
||||||
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
|
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, chunk_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (kHasZ) {
|
if constexpr (kHasZ) {
|
||||||
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
|
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
|
||||||
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
+ dim_id * kNRows * params.z_d_stride + tokens_processed;
|
||||||
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
|
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
|
||||||
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
+ dim_id * kNRows * params.out_z_d_stride + tokens_processed;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int r = 0; r < kNRows; ++r) {
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
input_t z_vals[kNItems];
|
input_t z_vals[kNItems];
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize);
|
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, chunk_tokens);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < kNItems; ++i) {
|
for (int i = 0; i < kNItems; ++i) {
|
||||||
float z_val = z_vals[i];
|
float z_val = z_vals[i];
|
||||||
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
|
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, chunk_tokens);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Bvar += kChunkSize * 1;
|
Bvar += chunk_tokens;
|
||||||
Cvar += kChunkSize * 1;
|
Cvar += chunk_tokens;
|
||||||
|
|
||||||
|
tokens_processed += chunk_tokens;
|
||||||
|
current_position += chunk_tokens;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -506,7 +531,9 @@ void set_ssm_params_fwd(SSMParamsBase &params,
|
|||||||
int64_t block_size,
|
int64_t block_size,
|
||||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &initial_state_idx) {
|
const std::optional<torch::Tensor> &initial_state_idx,
|
||||||
|
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
||||||
|
const std::optional<torch::Tensor> &last_chunk_indices) {
|
||||||
|
|
||||||
// Reset the parameters
|
// Reset the parameters
|
||||||
memset(&params, 0, sizeof(params));
|
memset(&params, 0, sizeof(params));
|
||||||
@@ -548,6 +575,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
|
|||||||
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
|
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
|
||||||
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
|
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
|
||||||
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
|
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
|
||||||
|
params.cu_chunk_seqlen_ptr = cu_chunk_seqlen.has_value() ? cu_chunk_seqlen.value().data_ptr() : nullptr;
|
||||||
|
params.last_chunk_indices_ptr = last_chunk_indices.has_value() ? last_chunk_indices.value().data_ptr() : nullptr;
|
||||||
|
|
||||||
// All stride are in elements, not bytes.
|
// All stride are in elements, not bytes.
|
||||||
params.A_d_stride = A.stride(0);
|
params.A_d_stride = A.stride(0);
|
||||||
@@ -633,7 +662,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
int64_t block_size,
|
int64_t block_size,
|
||||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &initial_state_idx) {
|
const std::optional<torch::Tensor> &initial_state_idx,
|
||||||
|
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
||||||
|
const std::optional<torch::Tensor> &last_chunk_indices) {
|
||||||
auto input_type = u.scalar_type();
|
auto input_type = u.scalar_type();
|
||||||
auto weight_type = A.scalar_type();
|
auto weight_type = A.scalar_type();
|
||||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||||
@@ -778,7 +809,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
block_size,
|
block_size,
|
||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
initial_state_idx
|
initial_state_idx,
|
||||||
|
cu_chunk_seqlen,
|
||||||
|
last_chunk_indices
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -371,7 +371,9 @@ void selective_scan_fwd(
|
|||||||
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
|
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
|
||||||
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
||||||
const std::optional<torch::Tensor>& initial_state_idx);
|
const std::optional<torch::Tensor>& initial_state_idx,
|
||||||
|
const std::optional<torch::Tensor>& cu_chunk_seqlen,
|
||||||
|
const std::optional<torch::Tensor>& last_chunk_indices);
|
||||||
|
|
||||||
torch::Tensor dynamic_4bit_int_moe_cpu(
|
torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||||
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
||||||
|
|||||||
@@ -640,7 +640,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"int block_size,"
|
"int block_size,"
|
||||||
"Tensor? block_idx_first_scheduled_token,"
|
"Tensor? block_idx_first_scheduled_token,"
|
||||||
"Tensor? block_idx_last_scheduled_token,"
|
"Tensor? block_idx_last_scheduled_token,"
|
||||||
"Tensor? initial_state_idx) -> ()");
|
"Tensor? initial_state_idx,"
|
||||||
|
"Tensor? cu_chunk_seqlen,"
|
||||||
|
"Tensor? last_chunk_indices) -> ()");
|
||||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||||
|
|
||||||
// Hadamard transforms
|
// Hadamard transforms
|
||||||
|
|||||||
@@ -183,6 +183,8 @@ def selective_scan_opcheck_fn(
|
|||||||
block_idx_first_scheduled_token=None,
|
block_idx_first_scheduled_token=None,
|
||||||
block_idx_last_scheduled_token=None,
|
block_idx_last_scheduled_token=None,
|
||||||
initial_state_idx=None,
|
initial_state_idx=None,
|
||||||
|
cu_chunk_seqlen=None,
|
||||||
|
last_chunk_indices=None,
|
||||||
):
|
):
|
||||||
"""if return_last_state is True, returns (out, last_state)
|
"""if return_last_state is True, returns (out, last_state)
|
||||||
last_state has shape (batch, dim, dstate).
|
last_state has shape (batch, dim, dstate).
|
||||||
@@ -231,6 +233,8 @@ def selective_scan_opcheck_fn(
|
|||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
initial_state_idx,
|
initial_state_idx,
|
||||||
|
cu_chunk_seqlen,
|
||||||
|
last_chunk_indices,
|
||||||
),
|
),
|
||||||
test_utils=["test_schema", "test_faketensor"],
|
test_utils=["test_schema", "test_faketensor"],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2021,6 +2021,8 @@ def selective_scan_fwd(
|
|||||||
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
||||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||||
initial_state_idx: torch.Tensor | None = None,
|
initial_state_idx: torch.Tensor | None = None,
|
||||||
|
cu_chunk_seqlen: torch.Tensor | None = None,
|
||||||
|
last_chunk_indices: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
torch.ops._C.selective_scan_fwd(
|
torch.ops._C.selective_scan_fwd(
|
||||||
u,
|
u,
|
||||||
@@ -2041,6 +2043,8 @@ def selective_scan_fwd(
|
|||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
initial_state_idx,
|
initial_state_idx,
|
||||||
|
cu_chunk_seqlen,
|
||||||
|
last_chunk_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer):
|
|||||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||||
ssm_state = self_kv_cache[1]
|
ssm_state = self_kv_cache[1]
|
||||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||||
|
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
|
||||||
|
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
|
||||||
|
|
||||||
# 1. Gated MLP's linear projection
|
# 1. Gated MLP's linear projection
|
||||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||||
@@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer):
|
|||||||
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
|
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
|
||||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
|
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
|
||||||
initial_state_idx=block_idx_last_computed_token_p,
|
initial_state_idx=block_idx_last_computed_token_p,
|
||||||
|
cu_chunk_seqlen=cu_chunk_seqlen_p,
|
||||||
|
last_chunk_indices=last_chunk_indices_p,
|
||||||
)
|
)
|
||||||
ssm_outputs.append(scan_out_p)
|
ssm_outputs.append(scan_out_p)
|
||||||
|
|
||||||
|
|||||||
@@ -497,6 +497,8 @@ def selective_scan_fn(
|
|||||||
block_idx_first_scheduled_token=None,
|
block_idx_first_scheduled_token=None,
|
||||||
block_idx_last_scheduled_token=None,
|
block_idx_last_scheduled_token=None,
|
||||||
initial_state_idx=None,
|
initial_state_idx=None,
|
||||||
|
cu_chunk_seqlen=None,
|
||||||
|
last_chunk_indices=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||||
@@ -588,6 +590,8 @@ def selective_scan_fn(
|
|||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
initial_state_idx,
|
initial_state_idx,
|
||||||
|
cu_chunk_seqlen,
|
||||||
|
last_chunk_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
if z is None:
|
if z is None:
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from vllm.v1.attention.backend import AttentionBackend
|
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
|
||||||
from vllm.v1.attention.backends.mamba_attn import (
|
from vllm.v1.attention.backends.mamba_attn import (
|
||||||
BaseMambaAttentionMetadata,
|
BaseMambaAttentionMetadata,
|
||||||
BaseMambaAttentionMetadataBuilder,
|
BaseMambaAttentionMetadataBuilder,
|
||||||
@@ -29,3 +30,31 @@ class Mamba1AttentionMetadataBuilder(
|
|||||||
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
|
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
|
||||||
):
|
):
|
||||||
metadata_cls = Mamba1AttentionMetadata
|
metadata_cls = Mamba1AttentionMetadata
|
||||||
|
|
||||||
|
def build(
|
||||||
|
self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
fast_build: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Mamba1AttentionMetadata:
|
||||||
|
common = self._compute_common_metadata(common_attn_metadata)
|
||||||
|
|
||||||
|
if (
|
||||||
|
common.num_prefills > 0
|
||||||
|
and self.vllm_config.cache_config.mamba_cache_mode == "all"
|
||||||
|
):
|
||||||
|
cu_chunk_seqlen_p, _, last_chunk_indices_p = (
|
||||||
|
self._build_chunk_metadata_tensors(
|
||||||
|
self.kv_cache_spec.block_size,
|
||||||
|
common,
|
||||||
|
common_attn_metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return replace(
|
||||||
|
common,
|
||||||
|
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
|
||||||
|
last_chunk_indices_p=last_chunk_indices_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
return common
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Any
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.utils.math_utils import cdiv
|
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
@@ -105,14 +104,6 @@ class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
|
|||||||
|
|
||||||
# Chunk-related metadata (only for prefill)
|
# Chunk-related metadata (only for prefill)
|
||||||
seq_idx_p: torch.Tensor | None = None
|
seq_idx_p: torch.Tensor | None = None
|
||||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
|
||||||
# each chunk, its offsets into the varlen sequence dimension. It is defined
|
|
||||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
|
||||||
# cu_chunk_seqlen_p[i+1].
|
|
||||||
cu_chunk_seqlen_p: torch.Tensor | None = None
|
|
||||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
|
||||||
# index of the last chunk for every sequence in the (prefill) batch.
|
|
||||||
last_chunk_indices_p: torch.Tensor | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class Mamba2AttentionMetadataBuilder(
|
class Mamba2AttentionMetadataBuilder(
|
||||||
@@ -134,68 +125,6 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
)
|
)
|
||||||
self.chunk_size: int = chunk_size
|
self.chunk_size: int = chunk_size
|
||||||
|
|
||||||
def _compute_chunk_metadata(
|
|
||||||
self,
|
|
||||||
num_prefills: int,
|
|
||||||
num_computed_tokens_p_cpu: torch.Tensor,
|
|
||||||
query_start_loc_p_cpu: torch.Tensor,
|
|
||||||
) -> tuple[list[int], list[int], list[int]]:
|
|
||||||
"""
|
|
||||||
Compute chunk-specific metadata for Mamba2.
|
|
||||||
|
|
||||||
The code below carefully constructs the chunks such that:
|
|
||||||
1. Chunks contain tokens from a *single* sequence only.
|
|
||||||
2. For every sequence, we are guaranteed that we can
|
|
||||||
retrieve the mamba state *every* chunk_size tokens.
|
|
||||||
Constraint (1) dramatically simplifies the mamba2 kernels.
|
|
||||||
Constraint (2) dramatically simplifies the implementation
|
|
||||||
of prefix caching for mamba2 (wip). We need to take care
|
|
||||||
of the interaction with chunked prefill in order to
|
|
||||||
satisfy constraint (2).
|
|
||||||
"""
|
|
||||||
# TODO (tdoublep): This code could probably be optimized.
|
|
||||||
cu_chunk_seqlen = []
|
|
||||||
seq_idx = []
|
|
||||||
last_chunk_indices = []
|
|
||||||
seqlen_pos = 0
|
|
||||||
|
|
||||||
for req_idx in range(num_prefills):
|
|
||||||
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
|
||||||
this_new_tokens = (
|
|
||||||
query_start_loc_p_cpu[req_idx + 1].item()
|
|
||||||
- query_start_loc_p_cpu[req_idx].item()
|
|
||||||
)
|
|
||||||
|
|
||||||
# if computed tokens are not chunk-aligned, use the first
|
|
||||||
# chunk to finish it off
|
|
||||||
if this_num_computed % self.chunk_size != 0:
|
|
||||||
seq_idx.append(req_idx)
|
|
||||||
cu_chunk_seqlen.append(seqlen_pos)
|
|
||||||
# how many tokens to finish the chunk?
|
|
||||||
chunk_len = (
|
|
||||||
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
|
|
||||||
- this_num_computed
|
|
||||||
)
|
|
||||||
# we can only use at most this_new_tokens
|
|
||||||
chunk_len = min(chunk_len, this_new_tokens)
|
|
||||||
seqlen_pos += chunk_len
|
|
||||||
this_new_tokens -= chunk_len
|
|
||||||
|
|
||||||
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
|
||||||
for chunk in range(n_chunks):
|
|
||||||
seq_idx.append(req_idx)
|
|
||||||
cu_chunk_seqlen.append(seqlen_pos)
|
|
||||||
chunk_len = min(self.chunk_size, this_new_tokens)
|
|
||||||
seqlen_pos += chunk_len
|
|
||||||
this_new_tokens -= chunk_len
|
|
||||||
|
|
||||||
assert this_new_tokens == 0
|
|
||||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
|
||||||
|
|
||||||
cu_chunk_seqlen.append(seqlen_pos)
|
|
||||||
|
|
||||||
return cu_chunk_seqlen, seq_idx, last_chunk_indices
|
|
||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
@@ -220,41 +149,12 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
else False
|
else False
|
||||||
)
|
)
|
||||||
|
|
||||||
num_reqs = common.num_reqs
|
cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = (
|
||||||
num_prefills = common.num_prefills
|
self._build_chunk_metadata_tensors(
|
||||||
num_decode_tokens = common.num_decode_tokens
|
self.chunk_size,
|
||||||
|
common,
|
||||||
num_computed_tokens_cpu = (
|
common_attn_metadata,
|
||||||
common_attn_metadata.compute_num_computed_tokens().cpu()
|
)
|
||||||
)
|
|
||||||
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
|
|
||||||
num_reqs - num_prefills : num_reqs
|
|
||||||
]
|
|
||||||
query_start_loc_p_cpu = (
|
|
||||||
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
|
|
||||||
- num_decode_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
|
|
||||||
num_prefills,
|
|
||||||
num_computed_tokens_p_cpu,
|
|
||||||
query_start_loc_p_cpu,
|
|
||||||
)
|
|
||||||
|
|
||||||
seq_idx_p = torch.as_tensor(
|
|
||||||
seq_idx,
|
|
||||||
device=common_attn_metadata.query_start_loc.device,
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
cu_chunk_seqlen_p = torch.as_tensor(
|
|
||||||
cu_chunk_seqlen,
|
|
||||||
device=common_attn_metadata.query_start_loc.device,
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
last_chunk_indices_p = torch.as_tensor(
|
|
||||||
last_chunk_indices,
|
|
||||||
device=common_attn_metadata.query_start_loc.device,
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return replace(
|
return replace(
|
||||||
|
|||||||
@@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata:
|
|||||||
# The following tensor is only used for prefix caching in align mode
|
# The following tensor is only used for prefix caching in align mode
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
|
|
||||||
|
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||||
|
# each chunk, its offsets into the varlen sequence dimension. It is defined
|
||||||
|
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||||
|
# cu_chunk_seqlen_p[i+1].
|
||||||
|
cu_chunk_seqlen_p: torch.Tensor | None = None
|
||||||
|
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||||
|
# index of the last chunk for every sequence in the (prefill) batch.
|
||||||
|
last_chunk_indices_p: torch.Tensor | None = None
|
||||||
|
|
||||||
# The following attributes are for triton implementation of causal_conv1d
|
# The following attributes are for triton implementation of causal_conv1d
|
||||||
nums_dict: dict | None = None
|
nums_dict: dict | None = None
|
||||||
batch_ptr: torch.Tensor | None = None
|
batch_ptr: torch.Tensor | None = None
|
||||||
@@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
|||||||
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
|
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _compute_chunk_metadata(
    self,
    chunk_size: int,
    num_prefills: int,
    num_computed_tokens_p_cpu: torch.Tensor,
    query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
    """
    Partition the new tokens of each prefill request into chunks.

    The partition is built so that:
      1. Every chunk holds tokens of exactly one request.
      2. Within a request, chunk boundaries fall on multiples of
         ``chunk_size`` counted from the start of the sequence
         (i.e. including the tokens that were already computed).

    Property (1) keeps the mamba kernels simple. Property (2) is what
    prefix caching relies on; to honour it together with chunked prefill,
    a request whose already-computed token count is not a multiple of
    ``chunk_size`` first gets a short chunk that completes the partially
    filled one.

    Args:
        chunk_size: Maximum number of tokens per chunk.
        num_prefills: Number of prefill requests to process.
        num_computed_tokens_p_cpu: Per-request count of already computed
            tokens; entry ``i`` is read with ``.item()``.
        query_start_loc_p_cpu: Cumulative query start offsets; entries
            ``i`` and ``i + 1`` delimit the new tokens of request ``i``.

    Returns:
        A tuple ``(cu_chunk_seqlen, seq_idx, last_chunk_indices)``:
          * ``cu_chunk_seqlen``: start offset of each chunk followed by
            the total token count (``nchunks + 1`` entries, starting at 0).
          * ``seq_idx``: for each chunk, the index of the request it
            belongs to.
          * ``last_chunk_indices``: for each request, the index of its
            last chunk.
    """
    # TODO (tdoublep): This code could probably be optimized.
    chunk_starts: list[int] = []
    chunk_owner: list[int] = []
    final_chunk_per_req: list[int] = []
    cursor = 0  # running offset into the flattened token dimension

    for req in range(num_prefills):
        already_computed = num_computed_tokens_p_cpu[req].item()
        remaining = (
            query_start_loc_p_cpu[req + 1].item()
            - query_start_loc_p_cpu[req].item()
        )

        # A previous step stopped in the middle of a chunk: spend the
        # first chunk on topping it up to the next chunk boundary.
        overhang = already_computed % chunk_size
        if overhang != 0:
            chunk_owner.append(req)
            chunk_starts.append(cursor)
            # Tokens missing up to the boundary, capped by what this
            # step actually has available.
            filler = min(chunk_size - overhang, remaining)
            cursor += filler
            remaining -= filler

        # Whatever is left is cut into full chunks plus, possibly,
        # one shorter trailing chunk.
        for _ in range(cdiv(remaining, chunk_size)):
            chunk_owner.append(req)
            chunk_starts.append(cursor)
            step = min(chunk_size, remaining)
            cursor += step
            remaining -= step

        assert remaining == 0
        final_chunk_per_req.append(len(chunk_starts) - 1)

    # Closing offset so that chunk i spans [starts[i], starts[i + 1]).
    chunk_starts.append(cursor)

    return chunk_starts, chunk_owner, final_chunk_per_req
|
||||||
|
|
||||||
|
def _build_chunk_metadata_tensors(
    self,
    chunk_size: int,
    common: M,
    common_attn_metadata: CommonAttentionMetadata,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build the chunk metadata for the prefill requests as device tensors.

    The per-request inputs are sliced on the CPU, handed to
    ``_compute_chunk_metadata``, and the resulting lists are converted to
    ``int32`` tensors on the device of
    ``common_attn_metadata.query_start_loc``.

    Args:
        chunk_size: Maximum number of tokens per chunk.
        common: Metadata providing ``num_reqs``, ``num_prefills`` and
            ``num_decode_tokens``.
        common_attn_metadata: Source of the computed-token counts and of
            the query start locations.

    Returns:
        ``(cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p)``.
    """
    total_reqs = common.num_reqs
    prefill_count = common.num_prefills
    decode_token_count = common.num_decode_tokens

    # The trailing `prefill_count` entries are taken as the prefill
    # requests (assumes decodes are ordered before prefills -- confirm
    # against the batch layout used by the caller).
    computed_cpu = common_attn_metadata.compute_num_computed_tokens().cpu()
    computed_prefill_cpu = computed_cpu[total_reqs - prefill_count : total_reqs]

    # Shift the start offsets by the number of decode tokens.
    prefill_start_loc_cpu = (
        common_attn_metadata.query_start_loc_cpu[-prefill_count - 1 :]
        - decode_token_count
    )

    chunk_offsets, chunk_seq_ids, final_chunks = self._compute_chunk_metadata(
        chunk_size,
        prefill_count,
        computed_prefill_cpu,
        prefill_start_loc_cpu,
    )

    target_device = common_attn_metadata.query_start_loc.device

    def as_int32(values):
        # All three outputs share the same device and dtype.
        return torch.as_tensor(values, device=target_device, dtype=torch.int32)

    cu_chunk_seqlen_p = as_int32(chunk_offsets)
    seq_idx_p = as_int32(chunk_seq_ids)
    last_chunk_indices_p = as_int32(final_chunks)
    return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p
|
||||||
|
|
||||||
def _compute_prefix_caching_block_indices(
|
def _compute_prefix_caching_block_indices(
|
||||||
self,
|
self,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
|||||||
Reference in New Issue
Block a user