[Bugfix] Token type and position embeddings fail to be applied to inputs_embeds (#25922)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-01 00:23:12 +08:00
committed by GitHub
parent ef283548f7
commit 9f1c4ecaf2
2 changed files with 14 additions and 9 deletions

View File

@@ -56,11 +56,13 @@ class RobertaEmbedding(nn.Module):
     self,
     input_ids: torch.Tensor,
     position_ids: torch.Tensor,
+    inputs_embeds: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     token_type_ids = _decode_token_type_ids(input_ids)
-    inputs_embeds = self.word_embeddings(input_ids)
+    if inputs_embeds is None:
+        inputs_embeds = self.word_embeddings(input_ids)
     position_embeddings = self.position_embeddings(position_ids)
     token_type_embeddings = self.token_type_embeddings(token_type_ids)