[DeepSeek] Fix DeepSeek V3.2 Rope Embedding (#28968)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Yongye Zhu
2025-11-19 16:30:04 -05:00
committed by GitHub
parent 613abb50d5
commit 88f5b19f0b
2 changed files with 17 additions and 3 deletions

@@ -24,6 +24,7 @@ class MLAModules:
     q_b_proj: torch.nn.Module | None
     q_proj: torch.nn.Module | None
     indexer: torch.nn.Module | None
+    indexer_rotary_emb: torch.nn.Module | None
     is_sparse: bool
     topk_indices_buffer: torch.Tensor | None
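
For orientation, here is a runnable sketch of the dataclass around the hunk above. Only the fields visible in the diff are reproduced and the class name is invented here; the real MLAModules in vLLM carries additional projection modules.

from dataclasses import dataclass

import torch


@dataclass
class MLAModulesSketch:
    # Only the fields visible in the hunk above; vLLM's MLAModules has more.
    q_b_proj: torch.nn.Module | None
    q_proj: torch.nn.Module | None
    indexer: torch.nn.Module | None
    indexer_rotary_emb: torch.nn.Module | None  # added by this commit
    is_sparse: bool
    topk_indices_buffer: torch.Tensor | None

Because the new field has no default, any model code that constructs MLAModules now has to supply it explicitly.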
@@ -80,6 +81,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
         self.rotary_emb = mla_modules.rotary_emb
         self.o_proj = mla_modules.o_proj
         self.indexer = mla_modules.indexer
+        self.indexer_rope_emb = mla_modules.indexer_rotary_emb
         self.is_sparse = mla_modules.is_sparse
         if self.indexer is not None:
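
A minimal sketch of the wiring this hunk adds, with the surrounding class assumed rather than copied from vLLM; note that the attribute is named indexer_rope_emb while the dataclass field is indexer_rotary_emb.

class WrapperWiringSketch:
    def __init__(self, mla_modules) -> None:
        # Two rotary-embedding modules now live side by side on the wrapper.
        self.rotary_emb = mla_modules.rotary_emb                # rope for the MLA q/k path
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer                      # sparse top-k indexer, may be None
        self.indexer_rope_emb = mla_modules.indexer_rotary_emb  # rope handed to the indexer
        self.is_sparse = mla_modules.is_sparse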
@@ -153,7 +155,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
         )
         if self.indexer and self.is_sparse:
-            _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
+            _topk_indices = self.indexer(
+                hidden_states, q_c, positions, self.indexer_rope_emb
+            )
         attn_out = self.mla_attn(
             q,
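
The rerouting presumably matters because the indexer's rotary embedding is configured differently from the one used by the MLA attention path; the diff only shows the rerouting, so the exact difference is an assumption here. The toy helper below is not vLLM code; it just illustrates that the rotation applied to queries and keys depends entirely on the module's own parameters, so handing the indexer the MLA module would rotate its inputs with the wrong configuration.

import torch


def toy_rope(x: torch.Tensor, positions: torch.Tensor, rotary_dim: int, base: float = 10000.0) -> torch.Tensor:
    # Neox-style rotation of the first rotary_dim channels of x (shape [seq, dim]);
    # channels beyond rotary_dim pass through unchanged.
    half = rotary_dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
    angles = positions.float()[:, None] * inv_freq[None, :]  # [seq, half]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:rotary_dim]
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, x[..., rotary_dim:]), dim=-1)


# Two differently configured rotations transform the same input differently, which is
# why the indexer needs its own rotary-embedding module rather than the MLA one.
x = torch.randn(4, 128)
positions = torch.arange(4)
assert not torch.allclose(toy_rope(x, positions, rotary_dim=64),
                          toy_rope(x, positions, rotary_dim=128))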