[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Cyrus Leung
2025-09-27 16:15:12 +08:00
committed by GitHub
parent 176173989a
commit 27d7638b94
80 changed files with 966 additions and 1139 deletions

View File

@@ -101,6 +101,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@@ -142,6 +145,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
prefix=maybe_prefix(
prefix, "model"))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,