[V1] Refactor model executable interface for all text-only language models (#10374)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang
2024-11-16 21:18:46 -08:00
committed by GitHub
parent 4fd9375028
commit 643ecf7b11
43 changed files with 483 additions and 90 deletions

View File

@@ -292,6 +292,9 @@ class JambaModel(nn.Module):
self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
@@ -299,8 +302,12 @@ class JambaModel(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
@@ -381,12 +388,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
config.vocab_size)
self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
@@ -409,7 +420,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_params)
attn_metadata, mamba_cache_params,
inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):