[V1] Initial support of multimodal models for V1 re-arch (#10699)

Signed-off-by: Roger Wang <ywang@roblox.com>
Roger Wang authored on 2024-12-08 04:50:51 -08:00; committed by GitHub
parent fd57d2b534
commit a11f326528
11 changed files with 283 additions and 68 deletions

@@ -409,16 +409,42 @@ def merge_multimodal_embeddings(
     input_ids: torch.Tensor,
     inputs_embeds: torch.Tensor,
     multimodal_embeddings: NestedTensors,
-    placeholder_token_id: int,
+    placeholder_token_id: Union[int, List[int]],
 ) -> torch.Tensor:
     """
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
     positions in ``inputs_embeds`` corresponding to placeholder tokens in
     ``input_ids``.
 
+    ``placeholder_token_id`` can be a list of token ids (e.g., the token ids
+    of img_start, img_break, and img_end tokens) when needed: this means
+    the order of these tokens in ``input_ids`` MUST MATCH the order of
+    their embeddings in ``multimodal_embeddings``, since we slice-merge
+    instead of individually scattering.
+
+    For example, if ``input_ids`` is "TTTTTSIIIBIIIBIIIETTT", where
+    - T is a text token
+    - S is the image start token
+    - I is an image embedding token
+    - B is the image break token
+    - E is the image end token,
+
+    then the image embeddings (corresponding to the I's) from the vision
+    encoder must be padded with the embeddings of S, B, and E in the same
+    order as in ``input_ids`` for a correct embedding merge.
+
     Note:
         This updates ``inputs_embeds`` in place.
     """
+    if isinstance(placeholder_token_id, list):
+        placeholder_token_id = torch.tensor(placeholder_token_id,
+                                            device=input_ids.device)
+        return _merge_multimodal_embeddings(
+            inputs_embeds,
+            torch.isin(input_ids, placeholder_token_id),
+            multimodal_embeddings,
+        )
+
     return _merge_multimodal_embeddings(
         inputs_embeds,
         (input_ids == placeholder_token_id),
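
To make the slice-merge contract concrete, here is a minimal, self-contained sketch of what the new list-of-ids path does. The token id values and tensor shapes are illustrative assumptions, and the final masked assignment stands in for _merge_multimodal_embeddings, whose body is not shown in this hunk:

import torch

# Hypothetical token ids: 1 = text (T), 10 = img_start (S),
# 11 = image embedding (I), 12 = img_break (B), 13 = img_end (E).
input_ids = torch.tensor([1, 1, 10, 11, 11, 12, 11, 11, 13, 1])
hidden_size = 4
inputs_embeds = torch.randn(input_ids.shape[0], hidden_size)

# One embedding row per S/I/B/E position, already ordered to match
# their order of appearance in input_ids, as the docstring requires.
multimodal_embeddings = torch.randn(7, hidden_size)

# torch.isin builds one boolean mask covering every placeholder id,
# so a single masked assignment overwrites all seven positions at once.
placeholder_token_id = torch.tensor([10, 11, 12, 13],
                                    device=input_ids.device)
is_placeholder = torch.isin(input_ids, placeholder_token_id)
assert int(is_placeholder.sum()) == multimodal_embeddings.shape[0]
inputs_embeds[is_placeholder] = multimodal_embeddings  # updates in place

Because boolean masking flattens positions left to right, row k of multimodal_embeddings lands at the k-th placeholder in input_ids; this is exactly why the docstring insists the two orders match instead of scattering each token id individually.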