[Model] Mamba2 causal conv1d Refactor to Split Prefill and Decode Requests for Corresponding Kernels (#17146)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
Chih-Chieh Yang
2025-05-06 20:59:30 -04:00
committed by GitHub
parent 6de3e13413
commit 18dd5e01f2
8 changed files with 151 additions and 123 deletions

View File

@@ -6,7 +6,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.mamba2_metadata import (
-_seq_idx_to_chunk_indices_offsets)
+_query_start_loc_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
from vllm.platforms import current_platform
@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
last_taken, exhausted, n_heads,
d_head, itype):
-chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-seq_idx, chunk_size)
+chunk_indices, chunk_offsets = \
+_query_start_loc_to_chunk_indices_offsets(
+cu_seqlens, chunk_size, cu_seqlens[-1])
Y, new_states = mamba_chunk_scan_combined(
X,