[Kernel] Triton implementation of causal-conv1d for Mamba-based models (#18218)

Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tuan, Hoang-Trong
2025-07-09 15:53:55 -04:00
committed by GitHub
parent 31b96d1c64
commit 47043eb678
15 changed files with 1117 additions and 1142 deletions

View File

@@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def(