[V1] Refactor model executable interface for multimodal models (#10570)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@@ -356,8 +356,7 @@ def embed_multimodal(
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_token_id: int,
|
||||
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
|
||||
get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor,
|
||||
List[torch.Tensor]]],
|
||||
multimodal_embeds: Union[torch.Tensor, List[torch.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Embed token IDs and multimodal inputs and combine their embeddings.
|
||||
@@ -374,8 +373,6 @@ def embed_multimodal(
|
||||
is_text = ~is_multimodal
|
||||
|
||||
text_embeds = get_text_embeds(input_ids[is_text])
|
||||
multimodal_embeds = get_multimodal_embeds(input_ids[is_multimodal])
|
||||
|
||||
merged_embeds = torch.empty(
|
||||
(input_ids.shape[0], text_embeds.shape[1]),
|
||||
dtype=text_embeds.dtype,
|
||||
|
||||
Reference in New Issue
Block a user