[V1] Refactor model executable interface for all text-only language models (#10374)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@@ -106,15 +106,22 @@ class MambaModel(nn.Module):
|
||||
self.norm_f = RMSNorm(config.hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states = self.embeddings(input_ids)
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
@@ -168,12 +175,16 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.backbone.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
if self.mamba_cache is None:
|
||||
max_batch_size = (_get_graph_batch_size(
|
||||
@@ -194,7 +205,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
state_indices_tensor)
|
||||
|
||||
hidden_states = self.backbone(input_ids, positions, attn_metadata,
|
||||
mamba_cache_params)
|
||||
mamba_cache_params, inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user