[V1] Refactor model executable interface for all text-only language models (#10374)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@@ -78,6 +78,9 @@ class EAGLE(nn.Module):
|
||||
def sampler(self):
|
||||
return self.model.sampler
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -86,11 +89,14 @@ class EAGLE(nn.Module):
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
tok_embeds = self.model.model.embed_tokens(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
|
||||
inputs_embeds = self.fc(
|
||||
torch.cat([tok_embeds, previous_hidden_states], dim=-1))
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
|
||||
|
||||
@@ -100,7 +106,8 @@ class EAGLE(nn.Module):
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors)
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user