[V1] Refactor model executable interface for all text-only language models (#10374)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@@ -578,6 +578,9 @@ class QWenModel(nn.Module):
|
||||
quant_config=quant_config) if hasattr(
|
||||
config, "visual") else None
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.wte(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -586,6 +589,7 @@ class QWenModel(nn.Module):
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
pixel_values: Optional[QwenImageInputs],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
img_pos = None
|
||||
# If pixel / visual embeddings are provided, this is a visual model
|
||||
@@ -606,6 +610,10 @@ class QWenModel(nn.Module):
|
||||
)
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.wte(input_ids)
|
||||
# Merge the image embeddings into the hidden states if actually have
|
||||
# visual features and the corresponding image tokens
|
||||
@@ -915,6 +923,9 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
)
|
||||
return None
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -922,7 +933,8 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
@@ -932,7 +944,7 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
pixel_values)
|
||||
pixel_values, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
||||
Reference in New Issue
Block a user