[Kernel][Model] Varlen prefill + Prefill chunking support for mamba kernels and Jamba model (#8533)

This commit is contained in:
Mor Zusman
2024-09-30 00:35:58 +03:00
committed by GitHub
parent 6c9ba48fde
commit f13a07b1f8
13 changed files with 1176 additions and 894 deletions

View File

@@ -24,6 +24,7 @@ struct ConvParamsBase {
index_t out_c_stride;
index_t out_l_stride;
int conv_state_len;
index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;
@@ -35,6 +36,10 @@ struct ConvParamsBase {
void *__restrict__ out_ptr;
void *__restrict__ conv_state_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ cache_indices_ptr;
int32_t *__restrict__ cache_seqlens;
// Used in the continuous batching case, so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
@@ -52,6 +57,11 @@ struct ConvParamsBase {
index_t final_states_batch_stride;
index_t final_states_l_stride;
index_t final_states_c_stride;
void * conv_states_ptr;
index_t conv_states_batch_stride;
index_t conv_states_l_stride;
index_t conv_states_c_stride;
};