[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: NickLucche <nlucches@redhat.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -117,6 +117,9 @@ class MiMoMultiTokenPredictor(nn.Module):
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -158,6 +161,9 @@ class MiMoMTP(nn.Module):
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user