[Model] Add Ultravox support for multiple audio chunks (#7963)
@@ -29,12 +29,12 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.utils import (filter_weights,
+from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
                                               init_vllm_registered_model,
                                               merge_multimodal_embeddings)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.base import MultiModalInputs
+from vllm.multimodal.base import MultiModalInputs, NestedTensors
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
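The two new imports carry most of the change: flatten_bn collapses the batch and per-request audio dimensions before the encoder, and NestedTensors lets the input types express ragged batches. A minimal sketch of the shape contract the new code relies on (this is not vLLM's actual flatten_bn implementation):

import torch

def flatten_bn_sketch(x: torch.Tensor) -> torch.Tensor:
    # Collapse the batch (B) and per-request audio (N) dims into one, so a
    # (B, N, 80, M) feature tensor becomes (B * N, 80, M) for the encoder.
    return x.flatten(0, 1)

features = torch.zeros(2, 3, 80, 3000)  # B=2 requests, N=3 audios each
assert flatten_bn_sketch(features).shape == (6, 80, 3000)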
@@ -48,13 +48,14 @@ logger = init_logger(__name__)

 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size * num_audios, 80, M)"""
+    data: NestedTensors
+    """Shape: `(batch_size, num_audios, 80, M)"""


 class UltravoxAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
-    data: torch.Tensor
+    data: NestedTensors
+    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""


 UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
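Both TypedDicts now take NestedTensors instead of a plain tensor, so per-request audio counts and per-audio lengths may differ. A rough illustration of the two layouts the model code below has to handle; the recursive alias here is an assumption modeled on vllm.multimodal.base:

import torch
from typing import List, Union

# Assumed spelling of the alias: a tensor, or arbitrarily nested lists of
# tensors.
NestedTensors = Union[torch.Tensor, List["NestedTensors"]]

# Homogeneous batch: every request has N equally long audios -> one tensor.
homogeneous: NestedTensors = torch.zeros(2, 3, 80, 3000)  # (B, N, 80, M)

# Ragged batch: request 0 has two audios of different lengths, request 1 has
# one, so the data arrives as a list of lists of (80, M_i) tensors.
ragged: NestedTensors = [
    [torch.zeros(80, 1500), torch.zeros(80, 3000)],
    [torch.zeros(80, 2000)],
]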
@@ -85,24 +86,33 @@ def dummy_data_for_ultravox(

     audio_count = mm_counts["audio"]

-    audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [
-        _AUDIO_PLACEHOLDER_TOKEN
-    ]) * get_ultravox_max_audio_tokens(ctx) * audio_count
+    audio_placeholder = array(
+        VLLM_TOKEN_ID_ARRAY_TYPE,
+        [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
+
+    # Add a separator between each chunk.
+    audio_token_ids = (audio_placeholder +
+                       array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
     other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                             [0]) * (seq_len - len(audio_token_ids))

     audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
-    mm_dict = {
-        "audio":
-        audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count
-    }
+    mm_dict = {"audio": [audio_and_sr] * audio_count}

     return (SequenceData(audio_token_ids + other_token_ids), mm_dict)


 def input_mapper_for_ultravox(ctx: InputContext, data: object):
-    if isinstance(data, tuple):
-        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
+    if not isinstance(data, list):
+        data = [data]
+
+    audio_features = []
+    for audio_input in data:
+        if not isinstance(audio_input, tuple):
+            raise NotImplementedError(
+                f"Unsupported data type: {type(audio_input)}")
+
+        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
         feature_extractor = whisper_feature_extractor(ctx)

         if sr != feature_extractor.sampling_rate:
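The dummy-data change inserts a non-placeholder separator token after each run of placeholders, so two adjacent audios do not fuse into a single placeholder span when the processor later expands them. A worked example with hypothetical values (the placeholder id and per-audio maximum are made up here, and the token-id array typecode is assumed to be "l"):

from array import array

VLLM_TOKEN_ID_ARRAY_TYPE = "l"  # assumed to match vLLM's typecode
P = 128002                      # hypothetical audio placeholder token id
max_audio_tokens = 3            # hypothetical per-audio maximum
audio_count = 2

audio_placeholder = array(VLLM_TOKEN_ID_ARRAY_TYPE, [P]) * max_audio_tokens
# The trailing 0 after each run keeps adjacent audios from merging into a
# single placeholder span.
audio_token_ids = (audio_placeholder +
                   array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
assert list(audio_token_ids) == [P, P, P, 0, P, P, P, 0]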
@@ -121,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
             # Not enough audio; pad it.
             audio = np.pad(audio, (0, minimum_audio_length - len(audio)))

-        return MultiModalInputs({
-            "audio_features":
-            feature_extractor(audio,
-                              sampling_rate=sr,
-                              padding="longest",
-                              return_tensors="pt")["input_features"]
-        })
+        single_audio_features = feature_extractor(
+            audio, sampling_rate=sr, padding="longest",
+            return_tensors="pt")["input_features"]

-    raise NotImplementedError(f"Unsupported data type: {type(data)}")
+        # Remove the batch dimension because we're wrapping it in a list.
+        audio_features.append(single_audio_features.squeeze(0))
+
+    return MultiModalInputs({"audio_features": audio_features})


 def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
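The mapper now normalizes its input to a list and runs the feature extractor once per chunk, collecting (80, M_i) tensors whose M_i may differ. A standalone sketch of that loop using the Hugging Face extractor directly, assuming its default 16 kHz / 80 mel-bin Whisper settings:

import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor()

# Two chunks of different lengths, as (waveform, sampling_rate) tuples.
audios = [(np.zeros(16000), 16000), (np.zeros(32000), 16000)]

audio_features = []
for audio, sr in audios:
    feats = feature_extractor(audio, sampling_rate=sr, padding="longest",
                              return_tensors="pt")["input_features"]
    # Drop the singleton batch dim; the surrounding list carries the batching.
    audio_features.append(feats.squeeze(0))

# audio_features is a list of (80, M_i) tensors; M_i differs per chunk.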
@@ -138,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs

     feature_extractor = whisper_feature_extractor(ctx)
-    audio_data, sample_rate = multi_modal_data["audio"]
+    audios = multi_modal_data["audio"]
+    if not isinstance(audios, list):
+        audios = [audios]

-    audio_length = audio_data.shape[0]
-    if sample_rate != feature_extractor.sampling_rate:
-        # Account for resampling.
-        adjustment = feature_extractor.sampling_rate / sample_rate
-        audio_length = math.ceil(adjustment * audio_length)
+    audio_token_counts = []
+    for audio_data, sample_rate in audios:
+        audio_length = audio_data.shape[0]
+        if sample_rate != feature_extractor.sampling_rate:
+            # Account for resampling.
+            adjustment = feature_extractor.sampling_rate / sample_rate
+            audio_length = math.ceil(adjustment * audio_length)

-    feature_extractor_output_length = math.ceil(
-        (audio_length -
-         (feature_extractor.hop_length - 1)) / feature_extractor.hop_length)
+        feature_extractor_output_length = math.ceil(
+            (audio_length - (feature_extractor.hop_length - 1)) /
+            feature_extractor.hop_length)

-    uv_config = ctx.get_hf_config(UltravoxConfig)
-    audio_num_tokens = min(
-        max(
-            1,
-            math.ceil(feature_extractor_output_length /
-                      (uv_config.stack_factor * 2))),
-        get_ultravox_max_audio_tokens(ctx))
+        uv_config = ctx.get_hf_config(UltravoxConfig)
+        audio_num_tokens = min(
+            max(
+                1,
+                math.ceil(feature_extractor_output_length /
+                          (uv_config.stack_factor * 2))),
+            get_ultravox_max_audio_tokens(ctx))
+        audio_token_counts.append(audio_num_tokens)
+
     tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)

     new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
@@ -164,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
         placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
-        repeat_count=audio_num_tokens,
+        repeat_count=audio_token_counts,
     )

     # NOTE: Create a defensive copy of the original inputs
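repeat_and_pad_placeholder_tokens now receives one repeat count per audio instead of a single integer. Each count follows the formula in the loop above: mel frames from the hop length, then a frame-halving encoder plus a frame-stacking projector, clamped to at least 1 and at most the model maximum. A worked sketch with assumed constants (Whisper's hop_length of 160 at 16 kHz; the stack_factor and cap values are placeholders):

import math

hop_length, stack_factor, max_audio_tokens = 160, 8, 200

def num_audio_tokens(num_samples: int) -> int:
    mel_frames = math.ceil((num_samples - (hop_length - 1)) / hop_length)
    # The encoder halves the frame rate and the projector stacks frames, so
    # one audio token covers stack_factor * 2 mel frames.
    return min(max(1, math.ceil(mel_frames / (stack_factor * 2))),
               max_audio_tokens)

# One second at 16 kHz -> 100 mel frames -> ceil(100 / 16) = 7 tokens.
assert num_audio_tokens(16000) == 7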
@@ -338,45 +353,52 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of audio features. "
                                  f"Got type: {type(audio_features)}")

-            # Remove the N dimension until multiple audios are supported.
-            if isinstance(audio_features, torch.Tensor):
-                audio_features = audio_features.squeeze(1)
-            else:
-                audio_features = [t.squeeze(0) for t in audio_features]
-
             return UltravoxAudioFeatureInputs(type="audio_features",
                                               data=audio_features)

         if audio_embeds is not None:
-            if not isinstance(audio_embeds, torch.Tensor):
+            if not isinstance(audio_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of audio embeds. "
                                  f"Got type: {type(audio_embeds)}")

-            # Remove the N dimension until multiple audios are supported.
-            audio_embeds = audio_embeds.squeeze(1)
-
             return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                 data=audio_embeds)

         raise AssertionError("This line should be unreachable.")

     def _process_audio_input(
-        self, audio_input: UltravoxAudioInputs
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+            self, audio_input: UltravoxAudioInputs) -> NestedTensors:
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]

         audio_features = audio_input["data"]
-        if isinstance(audio_features, list):
-            # TODO: Batch these through the encoder/projector instead of
-            # serializing them.
-            return [
-                self._audio_features_to_embeddings(
-                    features.unsqueeze(0)).squeeze(0)
-                for features in audio_features
-            ]
-        else:
-            return self._audio_features_to_embeddings(audio_features)
+        if isinstance(audio_features, torch.Tensor):
+            # Combine the B and N dimensions for the encoder/projector
+            flattened = flatten_bn(audio_features)
+            flattened_embeddings = self._audio_features_to_embeddings(
+                flattened)
+
+            # Restore the original dimensions
+            embeddings = flattened_embeddings.unflatten(
+                0, audio_features.shape[:2])
+            return embeddings
+
+        result = []
+        # TODO: Batch heterogeneous tensors through the encoder/projector
+        for audio_features_item in audio_features:
+            if isinstance(audio_features_item, torch.Tensor):
+                result.append(
+                    self._audio_features_to_embeddings(audio_features_item))
+            else:
+                embeddings = [
+                    # Add a batch dimension to embed it, then remove it.
+                    self._audio_features_to_embeddings(tensor.unsqueeze(0)
+                                                       ).squeeze(0)
+                    for tensor in audio_features_item
+                ]
+                result.append(embeddings)
+
+        return result

     def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
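For the homogeneous case, _process_audio_input flattens (B, N, ...) to (B*N, ...), runs the encoder/projector in one pass, and unflattens the result; the ragged case falls back to a per-item loop. A shape walk-through with a stand-in encoder (all dimensions illustrative):

import torch

B, N, M, T, H = 2, 3, 3000, 188, 4096  # illustrative dimensions only

def encode(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the Whisper encoder + projector: (b, 80, M) -> (b, T, H).
    return torch.zeros(x.shape[0], T, H)

features = torch.zeros(B, N, 80, M)
flattened = features.flatten(0, 1)       # (B*N, 80, M), as flatten_bn does
flat_embeds = encode(flattened)          # (B*N, T, H), one encoder pass
embeds = flat_embeds.unflatten(0, features.shape[:2])  # back to (B, N, T, H)
assert embeds.shape == (B, N, T, H)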
@@ -393,7 +415,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
         with the `input_ids`.

         Args:
-            input_features: A batch of audio inputs, [1, 80, M].
+            audio_features: A batch of audio inputs [B, N, 80, M].
         """
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is not None:
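End to end, the change lets a single request carry several audio clips. A hedged usage sketch; the model id, placeholder string, and prompt layout are assumptions rather than something this diff pins down:

import numpy as np
from vllm import LLM

llm = LLM(model="fixie-ai/ultravox-v0_3")  # assumed model id

# One placeholder per audio; the literal placeholder string is an assumption.
prompt = ("<|reserved_special_token_0|>\n"
          "<|reserved_special_token_0|>\n"
          "What changed between the two clips?")

audio_a = (np.zeros(16000, dtype=np.float32), 16000)  # (waveform, sample rate)
audio_b = (np.zeros(32000, dtype=np.float32), 16000)

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"audio": [audio_a, audio_b]},
})
print(outputs[0].outputs[0].text)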