[Bugfix] Fix dtype mismatch in PaliGemma (#6367)

This commit is contained in:
Cyrus Leung
2024-07-12 23:22:18 +08:00
committed by GitHub
parent aea19f0989
commit 024ad87cdc
3 changed files with 12 additions and 5 deletions

View File

@@ -277,6 +277,7 @@ class GemmaModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None: