[Core][VLM] Add precise multi-modal placeholder tracking (#8346)

Signed-off-by: Peter Salas <peter@fixie.ai>
Author: Peter Salas
Date: 2024-11-01 16:21:10 -07:00 (committed by GitHub)
Parent: d151fde834
Commit: 6c0b7f548d
53 changed files with 913 additions and 281 deletions

File: vllm/model_executor/models/ultravox.py

@@ -2,7 +2,6 @@
 """PyTorch Ultravox model."""
 import math
-from array import array
 from functools import cached_property, lru_cache
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union, cast)
@@ -17,27 +16,27 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY
-from vllm.inputs.data import DecoderOnlyInputs, token_inputs
-from vllm.inputs.registry import InputContext
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext, token_inputs)
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.base import MultiModalInputs, NestedTensors
+from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
+                             NestedTensors)
 from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   consecutive_placeholder_ranges,
                                    repeat_and_pad_placeholder_tokens)
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-                           SequenceData)
+from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 from vllm.utils import is_list_of
 
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+                    init_vllm_registered_model,
+                    merge_multimodal_embeddings_from_map)
 
 _AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25
@@ -46,13 +45,13 @@ _AUDIO_TOKENS_PER_SECOND = 6.25
 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
     data: NestedTensors
-    """Shape: `(batch_size, num_audios, 80, M)"""
+    """Shape: `(batch_size, num_audios, 80, M)`"""
 
 
 class UltravoxAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
     data: NestedTensors
-    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
+    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
 
 
 UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
@@ -79,17 +78,16 @@ def dummy_seq_data_for_ultravox(
     seq_len: int,
     audio_count: int,
 ):
-    audio_placeholder = array(
-        VLLM_TOKEN_ID_ARRAY_TYPE,
-        [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
+    audio_length = min(get_ultravox_max_audio_tokens(ctx),
+                       seq_len // audio_count)
 
-    # Add a separator between each chunk.
-    audio_token_ids = (audio_placeholder +
-                       array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
-    other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                            [0]) * (seq_len - len(audio_token_ids))
-
-    return SequenceData(audio_token_ids + other_token_ids)
+    return SequenceData.from_prompt_token_counts(
+        (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
+        (0, seq_len - audio_length * audio_count)), {
+            "audio":
+            consecutive_placeholder_ranges(num_items=audio_count,
+                                           item_size=audio_length)
+        }
 
 
 def dummy_audio_for_ultravox(
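
For reference, the dummy-data path above now also reports where each audio's placeholder tokens sit in the prompt. Below is a small standalone sketch of that bookkeeping; the offset/length dict shape is an assumption inferred from how the ranges are consumed here, and the helper is illustrative, not the vLLM `consecutive_placeholder_ranges` implementation.

from typing import List, TypedDict


class PlaceholderRange(TypedDict):
    """One contiguous run of placeholder tokens in the prompt (assumed shape)."""
    offset: int  # index of the first placeholder token of this item
    length: int  # number of placeholder tokens belonging to this item


def consecutive_ranges(num_items: int, item_size: int) -> List[PlaceholderRange]:
    """Back-to-back ranges starting at offset 0, one per multi-modal item."""
    return [
        PlaceholderRange(offset=i * item_size, length=item_size)
        for i in range(num_items)
    ]


# Two dummy audios of 10 placeholder tokens each:
# [{'offset': 0, 'length': 10}, {'offset': 10, 'length': 10}]
print(consecutive_ranges(num_items=2, item_size=10))
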
@@ -107,10 +105,10 @@ def dummy_data_for_ultravox(
     mm_counts: Mapping[str, int],
 ):
     audio_count = mm_counts["audio"]
-    seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
+    seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
 
     mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
 
-    return (seq_data, mm_dict)
+    return DummyData(seq_data, mm_dict, ranges)
 
 
 def input_mapper_for_ultravox(ctx: InputContext, data: object):
@@ -164,6 +162,11 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
     if multi_modal_data is None or "audio" not in multi_modal_data:
         return inputs
 
+    if "multi_modal_placeholders" in inputs and "audio" in inputs[
+            "multi_modal_placeholders"]:
+        # The inputs already have placeholders.
+        return inputs
+
     feature_extractor = whisper_feature_extractor(ctx)
     audios = multi_modal_data["audio"]
     if not isinstance(audios, list):
@@ -197,7 +200,7 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
     tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
 
-    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
         tokenizer,
         inputs.get("prompt"),
         inputs["prompt_token_ids"],
@@ -208,7 +211,8 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"audio": ranges})
 
 
 class StackAudioFrames(nn.Module):
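
As context for the `ranges` value threaded into `token_inputs` above: the input processor expands each audio placeholder token to that audio's placeholder length and records the offset/length of every expanded run. A minimal self-contained sketch of the idea (names and the returned dict shape are assumptions for illustration, not the signature of vLLM's `repeat_and_pad_placeholder_tokens`):

from typing import Dict, List, Tuple


def expand_placeholders(
    token_ids: List[int],
    placeholder_id: int,
    lengths_per_item: List[int],
) -> Tuple[List[int], List[Dict[str, int]]]:
    """Expand each placeholder token to its item's length and record the ranges."""
    new_ids: List[int] = []
    ranges: List[Dict[str, int]] = []
    item = 0
    for tok in token_ids:
        if tok == placeholder_id and item < len(lengths_per_item):
            length = lengths_per_item[item]
            ranges.append({"offset": len(new_ids), "length": length})
            new_ids.extend([placeholder_id] * length)
            item += 1
        else:
            new_ids.append(tok)
    return new_ids, ranges


# A prompt with one audio placeholder (token 128002) that needs 4 audio tokens:
ids, ranges = expand_placeholders([1, 128002, 9906], 128002, [4])
# ids    -> [1, 128002, 128002, 128002, 128002, 9906]
# ranges -> [{'offset': 1, 'length': 4}]
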
@@ -472,9 +476,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
                 inputs_embeds = self.language_model.model.get_input_embeddings(
                     input_ids)
 
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, audio_embeddings,
-                    _AUDIO_PLACEHOLDER_TOKEN)
+                merge_multimodal_embeddings_from_map(
+                    inputs_embeds, audio_embeddings,
+                    attn_metadata.multi_modal_placeholder_index_maps["audio"])
                 input_ids = None
             else:
                 inputs_embeds = None
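
Conceptually, the new merge path in `forward` scatters the audio embeddings into the text embedding matrix at the placeholder positions supplied by the attention metadata, instead of re-deriving those positions from `input_ids`. A rough PyTorch sketch of that scatter follows; the flat source/destination index layout is an assumption for illustration, not the exact structure of the index maps carried on `attn_metadata.multi_modal_placeholder_index_maps`:

import torch


def merge_from_index_map(
    inputs_embeds: torch.Tensor,  # (num_tokens, hidden_size) text embeddings
    mm_embeds: torch.Tensor,      # (num_mm_tokens, hidden_size) audio embeddings
    dest_indices: torch.Tensor,   # placeholder positions within the prompt
    src_indices: torch.Tensor,    # rows of mm_embeds that fill those positions
) -> torch.Tensor:
    """Overwrite the placeholder rows of the text embeddings in place."""
    inputs_embeds[dest_indices] = mm_embeds[src_indices].to(inputs_embeds.dtype)
    return inputs_embeds


# 6-token prompt, hidden size 8; positions 1-4 are audio placeholders.
embeds = torch.zeros(6, 8)
audio = torch.randn(4, 8)
merged = merge_from_index_map(embeds, audio,
                              dest_indices=torch.arange(1, 5),
                              src_indices=torch.arange(4))
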