[Model] Mamba2 causal conv1d Refactor to Split Prefill and Decode Requests for Corresponding Kernels (#17146)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
_seq_idx_to_chunk_indices_offsets)
|
||||
_query_start_loc_to_chunk_indices_offsets)
|
||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined)
|
||||
from vllm.platforms import current_platform
|
||||
@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
last_taken, exhausted, n_heads,
|
||||
d_head, itype):
|
||||
|
||||
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
|
||||
seq_idx, chunk_size)
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
cu_seqlens, chunk_size, cu_seqlens[-1])
|
||||
|
||||
Y, new_states = mamba_chunk_scan_combined(
|
||||
X,
|
||||
|
||||
Reference in New Issue
Block a user