[Model] Add PaliGemma (#5189)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Roger Wang
2024-07-06 18:25:50 -07:00
committed by GitHub
parent 9389380015
commit 6206dcb29e
6 changed files with 557 additions and 2 deletions

View File

@@ -268,16 +268,22 @@ class GemmaModel(nn.Module):
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]