[V1] Refactor model executable interface for all text-only language models (#10374)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang
2024-11-16 21:18:46 -08:00
committed by GitHub
parent 4fd9375028
commit 643ecf7b11
43 changed files with 483 additions and 90 deletions

View File

@@ -106,15 +106,22 @@ class MambaModel(nn.Module):
self.norm_f = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(input_ids)
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
@@ -168,12 +175,16 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
config.vocab_size)
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.backbone.get_input_embeddings(input_ids)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
@@ -194,7 +205,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
state_indices_tensor)
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_params)
mamba_cache_params, inputs_embeds)
return hidden_states