[Feature] Add vision language model support. (#3042)

This commit is contained in:
xwjiang2010
2024-03-25 14:16:30 -07:00
committed by GitHub
parent f408d05c52
commit 64172a976c
28 changed files with 936 additions and 94 deletions

View File

@@ -250,14 +250,21 @@ class LlamaModel(nn.Module):
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]