[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Cyrus Leung
2025-09-27 16:15:12 +08:00
committed by GitHub
parent 176173989a
commit 27d7638b94
80 changed files with 966 additions and 1139 deletions

View File

@@ -395,6 +395,9 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,