[Frontend][Core] Add plumbing to support audio language models (#7446)
This commit is contained in:
@@ -54,41 +54,42 @@ def init_vllm_registered_model(
|
||||
)
|
||||
|
||||
|
||||
def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                multimodal_embeddings: BatchedTensors,
                                placeholder_token_id: int) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Args:
        input_ids: Token ids; every position equal to
            ``placeholder_token_id`` is treated as a placeholder.
        inputs_embeds: Embeddings indexed by the placeholder mask derived
            from ``input_ids``; the masked rows are overwritten.
        multimodal_embeddings: Either a single tensor whose shape unpacks as
            ``(batch_size, batch_tokens, ..., embed_dim)``, or a sequence of
            tensors that are concatenated along their first dimension.
        placeholder_token_id: Token id marking the positions to overwrite.

    Returns:
        The same ``inputs_embeds`` tensor that was passed in.

    Raises:
        ValueError: If the number of placeholder positions in ``input_ids``
            does not match the number of multimodal embedding rows supplied.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    mask = (input_ids == placeholder_token_id)
    # Convert to a plain int once so the comparisons below are ordinary
    # integer comparisons and the error message prints "3", not "tensor(3)".
    num_expected_tokens = int(mask.sum().item())

    if isinstance(multimodal_embeddings, torch.Tensor):
        batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
        total_tokens = batch_size * batch_tokens
        if num_expected_tokens != total_tokens:
            expr = f"{batch_size} x {batch_tokens}"
            raise ValueError(
                f"Attempted to assign {expr} = {total_tokens} "
                f"multimodal tokens to {num_expected_tokens} placeholders")

        # Flatten the batch and token dimensions so that each row lines up
        # with one placeholder position selected by the mask.
        inputs_embeds[mask] = multimodal_embeddings.view(
            total_tokens, embed_dim)
    else:
        # A sequence of per-item tensors: each contributes ``shape[0]`` rows.
        size_per_batch = [t.shape[0] for t in multimodal_embeddings]
        total_tokens = sum(size_per_batch)
        if num_expected_tokens != total_tokens:
            expr = ' + '.join(map(str, size_per_batch))
            raise ValueError(
                f"Attempted to assign {expr} = {total_tokens} "
                f"multimodal tokens to {num_expected_tokens} placeholders")

        inputs_embeds[mask] = torch.cat(multimodal_embeddings)

    return inputs_embeds
|
||||
|
||||
|
||||
Reference in New Issue
Block a user