[DeepSeek] Fix DeepSeek V3.2 Rope Embedding (#28968)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@@ -24,6 +24,7 @@ class MLAModules:
     q_b_proj: torch.nn.Module | None
     q_proj: torch.nn.Module | None
     indexer: torch.nn.Module | None
+    indexer_rotary_emb: torch.nn.Module | None
     is_sparse: bool
     topk_indices_buffer: torch.Tensor | None
@@ -80,6 +81,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
         self.rotary_emb = mla_modules.rotary_emb
         self.o_proj = mla_modules.o_proj
         self.indexer = mla_modules.indexer
+        self.indexer_rope_emb = mla_modules.indexer_rotary_emb
         self.is_sparse = mla_modules.is_sparse

         if self.indexer is not None:
@@ -153,7 +155,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
         )

         if self.indexer and self.is_sparse:
-            _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
+            _topk_indices = self.indexer(
+                hidden_states, q_c, positions, self.indexer_rope_emb
+            )

         attn_out = self.mla_attn(
             q,
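For context, here is a minimal standalone sketch of what the patch corrects. It is not vLLM's actual code: SimpleRope, index_queries, and the differing base values are illustrative assumptions standing in for whatever configuration actually distinguishes the two rotary embeddings. The point it demonstrates is the one visible in the diff above: the sparse indexer must receive its own indexer_rotary_emb rather than the main MLA path's rotary_emb, since the two RoPE modules can be configured differently and handing the indexer self.rotary_emb applies the wrong embedding.

    import torch
    import torch.nn as nn
    from dataclasses import dataclass


    class SimpleRope(nn.Module):
        # Plain neox-style RoPE. `base` stands in for whatever actually
        # differs between the two embeddings (base, scaling, style, ...).
        def __init__(self, head_dim: int, base: float) -> None:
            super().__init__()
            inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
            self.register_buffer("inv_freq", inv_freq)

        def forward(self, positions: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
            # positions: (T,), x: (T, head_dim)
            freqs = positions.float()[:, None] * self.inv_freq  # (T, head_dim / 2)
            cos, sin = freqs.cos(), freqs.sin()
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)


    @dataclass
    class MLAModules:
        rotary_emb: nn.Module          # RoPE used by the main MLA attention path
        indexer_rotary_emb: nn.Module  # separate RoPE for the sparse indexer


    def index_queries(mods: MLAModules, positions: torch.Tensor, q: torch.Tensor):
        # Hypothetical helper mirroring the fix: the indexer path consumes
        # mods.indexer_rotary_emb, not mods.rotary_emb.
        return mods.indexer_rotary_emb(positions, q)


    mods = MLAModules(
        rotary_emb=SimpleRope(head_dim=64, base=10000.0),
        indexer_rotary_emb=SimpleRope(head_dim=64, base=50000.0),  # assumed config
    )
    positions = torch.arange(8)
    q = torch.randn(8, 64)
    print(index_queries(mods, positions, q).shape)  # torch.Size([8, 64])

Before the patch, the indexer path was the only consumer that ignored this distinction; threading indexer_rotary_emb through MLAModules makes the dedicated module available wherever the indexer is invoked.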