[V1] Initial support of multimodal models for V1 re-arch (#10699)
Signed-off-by: Roger Wang <ywang@roblox.com>
@@ -409,16 +409,42 @@ def merge_multimodal_embeddings(
     input_ids: torch.Tensor,
     inputs_embeds: torch.Tensor,
     multimodal_embeddings: NestedTensors,
-    placeholder_token_id: int,
+    placeholder_token_id: Union[int, List[int]],
 ) -> torch.Tensor:
     """
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
     positions in ``inputs_embeds`` corresponding to placeholder tokens in
     ``input_ids``.
+
+    ``placeholder_token_id`` can be a list of token ids (e.g., token ids
+    of img_start, img_break, and img_end tokens) when needed. This means
+    the order of these tokens in the ``input_ids`` MUST MATCH the order of
+    their embeddings in ``multimodal_embeddings`` since we need to
+    slice-merge instead of individually scattering.
+
+    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
+    - T is a text token
+    - S is an image start token
+    - I is an image embedding token
+    - B is an image break token
+    - E is an image end token.
+
+    Then the image embeddings (that correspond to I's) from the vision encoder
+    must be padded with embeddings of S, B, and E in the same order as in
+    input_ids for a correct embedding merge.

     Note:
         This updates ``inputs_embeds`` in place.
     """
+    if isinstance(placeholder_token_id, list):
+        placeholder_token_id = torch.tensor(placeholder_token_id,
+                                            device=input_ids.device)
+        return _merge_multimodal_embeddings(
+            inputs_embeds,
+            torch.isin(input_ids, placeholder_token_id),
+            multimodal_embeddings,
+        )
+
     return _merge_multimodal_embeddings(
         inputs_embeds,
         (input_ids == placeholder_token_id),
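
The ordering requirement in the docstring can be illustrated with a small sketch. This is not the vLLM implementation: it uses plain Python lists instead of tensors, and strings stand in for embedding rows. The helper name `merge_by_mask` and the token ids for T, S, I, B, and E are made up for the example. It assumes only what the docstring states, namely that placeholder positions are overwritten in prompt order and that the update happens in place.

```python
from typing import List, Sequence, Union


def merge_by_mask(
    input_ids: Sequence[int],
    inputs_embeds: List[str],
    multimodal_embeddings: Sequence[str],
    placeholder_token_id: Union[int, List[int]],
) -> List[str]:
    """Overwrite placeholder positions, in order, with the given embeddings.

    Hypothetical list-based stand-in; strings play the role of embedding rows.
    """
    if isinstance(placeholder_token_id, list):
        placeholder_ids = set(placeholder_token_id)
    else:
        placeholder_ids = {placeholder_token_id}

    # Positions to overwrite, in prompt order (the role of the boolean mask).
    positions = [i for i, tok in enumerate(input_ids) if tok in placeholder_ids]
    if len(positions) != len(multimodal_embeddings):
        raise ValueError(
            f"{len(positions)} placeholder positions but "
            f"{len(multimodal_embeddings)} multimodal embeddings"
        )

    # The k-th embedding lands on the k-th placeholder position, so the two
    # sequences must already be in the same order.
    for pos, emb in zip(positions, multimodal_embeddings):
        inputs_embeds[pos] = emb  # in-place update
    return inputs_embeds


# Made-up token ids for the roles used in the docstring.
T, S, I, B, E = 0, 1, 2, 3, 4
input_ids = [T, S, I, I, B, I, I, E, T]  # "TSIIBIIET"
inputs_embeds = [f"text{i}" for i in range(len(input_ids))]

# Image embeddings already padded with S, B, and E rows, in prompt order.
padded = ["s", "i0", "i1", "b", "i2", "i3", "e"]
merged = merge_by_mask(input_ids, inputs_embeds, padded, [S, I, B, E])
print(merged)
```

In this sketch, passing only the four I rows together with the full list `[S, I, B, E]` raises a `ValueError`, because seven positions are selected but only four rows are supplied. That mismatch is the situation the docstring warns about when it asks for the encoder output to be padded with the S, B, and E embeddings.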