[V1] Refactor model executable interface for all text-only language models (#10374)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang
2024-11-16 21:18:46 -08:00
committed by GitHub
parent 4fd9375028
commit 643ecf7b11
43 changed files with 483 additions and 90 deletions

View File

@@ -78,6 +78,9 @@ class EAGLE(nn.Module):
# Expose the sampler held by the wrapped model (`self.model.sampler`).
def sampler(self):
return self.model.sampler
# Return embeddings for `input_ids` by delegating to the inner
# `self.model.model.get_input_embeddings`; the result is returned as-is.
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@@ -86,11 +89,14 @@ class EAGLE(nn.Module):
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
tok_embeds = self.model.model.embed_tokens(input_ids)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.fc(
torch.cat([tok_embeds, previous_hidden_states], dim=-1))
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
@@ -100,7 +106,8 @@ class EAGLE(nn.Module):
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors)
intermediate_tensors=intermediate_tensors,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,