[Bugfix][mamba] Fix type annotation of Mamba2Metadata (#22787)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-08-13 06:07:09 -07:00
committed by GitHub
parent 6b794c756c
commit fceafaf582
2 changed files with 26 additions and 21 deletions

View File

@@ -473,12 +473,12 @@ class MambaMixer2(MambaBase, CustomOp):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx
chunk_indices_p = attn_metadata.chunk_indices
chunk_offsets_p = attn_metadata.chunk_offsets
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
else:
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state