[Mamba] - Consolidate Mamba's Attention Logic (#28133)
commit 34916ae37f
parent 0736f901e7
committed by GitHub
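This change moves the prefill query-offset computation out of the per-layer forward path: ShortConv now reads query_start_loc_p directly from the shared attention metadata instead of rederiving it from query_start_loc on every call. The two hunks below show the read side and the removed per-layer computation.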
@@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
         conv_state = self_kv_cache[0].transpose(-1, -2)
         state_indices_tensor = attn_metadata.state_indices_tensor
         has_initial_states_p = attn_metadata.has_initial_states_p
+        query_start_loc_p = attn_metadata.query_start_loc_p
 
         BCx, _ = self.in_proj(hidden_states)
 
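For context, here is a minimal sketch of the consolidation pattern this hunk relies on. The MambaAttentionMetadata dataclass and its build_prefill_offsets method are hypothetical stand-ins, not vLLM's actual metadata classes; the offset formula mirrors the lines removed in the second hunk below.

    from dataclasses import dataclass
    import torch

    @dataclass
    class MambaAttentionMetadata:
        # Hypothetical stand-in: the consolidated metadata carries the
        # prefill-only offsets so every Mamba layer can read them directly.
        query_start_loc: torch.Tensor  # cumulative token offsets, decodes first
        num_prefills: int
        num_decodes: int
        query_start_loc_p: torch.Tensor | None = None

        def build_prefill_offsets(self) -> None:
            # Built once per batch instead of inside each layer's forward;
            # same formula as the per-layer code removed in the second hunk.
            if self.num_prefills > 0:
                self.query_start_loc_p = (
                    self.query_start_loc[-self.num_prefills - 1 :]
                    - self.num_decodes
                )

With this in place, a layer's forward only needs the attribute read shown in the hunk above.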
@@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
             [num_decodes, num_prefills],
             dim=0,
         )
-        query_start_loc_p = (
-            attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
-            if has_prefill
-            else None
-        )
 
         conv_output_list = []
 
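As a worked example of the arithmetic the removed lines implemented (assuming, as the removed code itself implies, that decode requests come first in the batch and contribute one token each):

    import torch

    # 2 decode requests (1 token each) followed by 2 prefills of 5 and 3 tokens.
    query_start_loc = torch.tensor([0, 1, 2, 7, 10])
    num_decodes, num_prefills = 2, 2

    # Keep the prefill segment of the cumulative offsets and rebase it to 0.
    query_start_loc_p = query_start_loc[-num_prefills - 1 :] - num_decodes
    print(query_start_loc_p)  # tensor([0, 5, 8])

The slice keeps the last num_prefills + 1 boundary entries, and subtracting num_decodes removes the decode tokens' contribution, yielding prefill-local cumulative offsets. After this commit that value arrives precomputed as attn_metadata.query_start_loc_p.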