[Mamba] - Consolidate Mambas Attention Logic (#28133)

This commit is contained in:
Asaf Joseph Gardin
2025-12-23 22:57:00 +02:00
committed by GitHub
parent 0736f901e7
commit 34916ae37f
5 changed files with 305 additions and 448 deletions

View File

@@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p
query_start_loc_p = attn_metadata.query_start_loc_p
BCx, _ = self.in_proj(hidden_states)
@@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
if has_prefill
else None
)
conv_output_list = []