[Kernel] Change interface to Mamba causal_conv1d_update for continuous batching (#8012)

This commit is contained in:
Tyler Michael Smith
2024-09-17 19:44:27 -04:00
committed by GitHub
parent 09deb4721f
commit 8110e44529
7 changed files with 114 additions and 16 deletions

View File

@@ -36,6 +36,10 @@ struct ConvParamsBase {
void *__restrict__ conv_state_ptr;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t *__restrict__ conv_state_indices_ptr;
void *__restrict__ seq_idx_ptr;
// No __restrict__ since initial_states could be the same as final_states.