[VLM] Merged multimodal processor for Qwen2-Audio (#11303)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -133,8 +133,8 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
hf_processor.__is_patched__ = True # type: ignore
|
||||
|
||||
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
|
||||
hf_processor = self.ctx.get_hf_processor()
|
||||
assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
|
||||
hf_processor = self.ctx.get_hf_processor(
|
||||
(LlavaProcessor, PixtralProcessor))
|
||||
|
||||
if isinstance(hf_processor, PixtralProcessor):
|
||||
self._patch_pixtral_processor(hf_processor)
|
||||
|
||||
@@ -34,7 +34,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataDict,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -330,20 +329,27 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
return self.ctx.get_hf_processor(num_crops=num_crops)
|
||||
return self.ctx.get_hf_processor()
|
||||
|
||||
def _apply_hf_processor(
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
hf_processor: ProcessorMixin,
|
||||
prompt: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
processor_data: Mapping[str, object],
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._apply_hf_processor(
|
||||
prompt, mm_data, mm_processor_kwargs)
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
hf_processor,
|
||||
prompt=prompt,
|
||||
processor_data=processor_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
# Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
|
||||
# which will cause OverflowError when decoding the prompt_ids.
|
||||
# Therefore, we need to do an early replacement here
|
||||
token_ids = processed_outputs['input_ids']
|
||||
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
|
||||
processed_outputs['input_ids'] = token_ids
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_prompt_replacements(
|
||||
|
||||
@@ -19,45 +19,43 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
||||
Union)
|
||||
from functools import cached_property
|
||||
from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import Qwen2AudioEncoder
|
||||
from transformers import BatchFeature, ProcessorMixin
|
||||
from transformers.models.qwen2_audio import (Qwen2AudioConfig,
|
||||
Qwen2AudioEncoder,
|
||||
Qwen2AudioProcessor)
|
||||
from transformers.models.whisper import WhisperFeatureExtractor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.utils import consecutive_placeholder_ranges
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# # === Audio Inputs === #
|
||||
class Qwen2AudioInputs(TypedDict):
|
||||
input_features: torch.Tensor
|
||||
"""Shape:
|
||||
`(num_audios, num_mel_bins, 3000)`
|
||||
"""
|
||||
"""Shape: `(num_audios, num_mel_bins, 3000)`"""
|
||||
|
||||
feature_attention_mask: torch.Tensor
|
||||
"""Shape: `(num_audios, 3000)`
|
||||
"""
|
||||
"""Shape: `(num_audios, 3000)`"""
|
||||
|
||||
|
||||
# === Audio Encoder === #
|
||||
@@ -74,187 +72,114 @@ class Qwen2AudioMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
num_audios = mm_counts["audio"]
|
||||
max_tokens_per_audio = get_max_qwen2_audio_audio_tokens(ctx)
|
||||
max_llm_audio_tokens = max_tokens_per_audio * num_audios
|
||||
if seq_len - max_llm_audio_tokens - 2 < 0:
|
||||
raise RuntimeError(
|
||||
f"Qwen2-Audio cannot process {num_audios} audios in a prompt, "
|
||||
"please increase max_model_len or reduce audio limit by "
|
||||
"--limit-mm-per-prompt.")
|
||||
|
||||
audio_token_index = ctx.model_config.hf_config.audio_token_index
|
||||
|
||||
dummy_seqdata = SequenceData.from_prompt_token_counts(
|
||||
(audio_token_index, max_llm_audio_tokens),
|
||||
(0, seq_len - max_llm_audio_tokens),
|
||||
)
|
||||
dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
|
||||
return DummyData(
|
||||
dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
|
||||
"audio":
|
||||
consecutive_placeholder_ranges(num_items=num_audios,
|
||||
item_size=max_tokens_per_audio)
|
||||
})
|
||||
|
||||
|
||||
def get_processor(
|
||||
processor_name: str,
|
||||
*args,
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Gets a processor for the given model name via HuggingFace.
|
||||
|
||||
Derived from `vllm.transformers_utils.image_processor.get_image_processor`.
|
||||
"""
|
||||
# don't put this import at the top level
|
||||
# it will call torch.cuda.device_count()
|
||||
from transformers import AutoProcessor
|
||||
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
processor_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
except ValueError as e:
|
||||
# If the error pertains to the processor class not existing or not
|
||||
# currently being imported, suggest using the --trust-remote-code flag.
|
||||
# Unlike AutoTokenizer, AutoProcessor does not separate such errors
|
||||
if not trust_remote_code:
|
||||
err_msg = (
|
||||
"Failed to load the processor. If the processor is "
|
||||
"a custom processor not yet available in the HuggingFace "
|
||||
"transformers library, consider setting "
|
||||
"`trust_remote_code=True` in LLM or using the "
|
||||
"`--trust-remote-code` flag in the CLI.")
|
||||
raise RuntimeError(err_msg) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
cached_get_processor = lru_cache(get_processor)
|
||||
|
||||
|
||||
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
and the output length of the audio encoder
|
||||
"""
|
||||
input_lengths = (input_lengths - 1) // 2 + 1
|
||||
output_lengths = (input_lengths - 2) // 2 + 1
|
||||
return input_lengths, output_lengths
|
||||
feat_lengths = (input_lengths - 1) // 2 + 1
|
||||
output_lengths = (feat_lengths - 2) // 2 + 1
|
||||
return feat_lengths, output_lengths
|
||||
|
||||
|
||||
def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
|
||||
max_source_position = (
|
||||
ctx.model_config.hf_config.audio_config.max_source_positions)
|
||||
hf_config = ctx.get_hf_config(Qwen2AudioConfig)
|
||||
max_source_position = hf_config.audio_config.max_source_positions
|
||||
output_lengths = (max_source_position - 2) // 2 + 1
|
||||
return output_lengths
|
||||
|
||||
|
||||
def input_processor_for_qwen2_audio(
|
||||
ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "audio" not in multi_modal_data:
|
||||
return inputs
|
||||
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
audios = multi_modal_data["audio"]
|
||||
if not isinstance(audios, list):
|
||||
audios = [audios]
|
||||
def _get_hf_processor(self) -> Qwen2AudioProcessor:
|
||||
return self.ctx.get_hf_processor(Qwen2AudioProcessor)
|
||||
|
||||
if len(audios) == 0:
|
||||
return inputs
|
||||
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
|
||||
return self._get_hf_processor().feature_extractor # type: ignore
|
||||
|
||||
processor = cached_get_processor(ctx.model_config.model)
|
||||
resampled_audios = [
|
||||
librosa.resample(audio,
|
||||
orig_sr=sampling_rate,
|
||||
target_sr=processor.feature_extractor.sampling_rate)
|
||||
for audio, sampling_rate in audios
|
||||
]
|
||||
audio_input_lengths = np.array(
|
||||
[min(3000, _.shape[0] // 160 + 1) for _ in resampled_audios])
|
||||
def _get_processor_data(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
# resample audio to the model's sampling rate
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
mm_items.resample_audios(feature_extractor.sampling_rate)
|
||||
|
||||
audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
|
||||
audio_input_lengths)
|
||||
return super()._get_processor_data(mm_items)
|
||||
|
||||
audio_token_index = ctx.model_config.hf_config.audio_token_index
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
hf_processor: ProcessorMixin,
|
||||
prompt: str,
|
||||
processor_data: Mapping[str, object],
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processor_data = dict(processor_data)
|
||||
audios = processor_data.pop("audios", [])
|
||||
|
||||
input_ids = inputs['prompt_token_ids']
|
||||
if audios:
|
||||
processor_data["audios"] = audios
|
||||
|
||||
new_input_ids = []
|
||||
audio_num = input_ids.count(audio_token_index)
|
||||
assert len(audio_input_lengths) == audio_num, \
|
||||
(f'The text input contains {audio_num} audio tokens, '
|
||||
f'but {len(audio_input_lengths)} audios provided')
|
||||
start = 0
|
||||
for audio_idx in range(audio_num):
|
||||
end = input_ids.index(audio_token_index, start)
|
||||
new_input_ids.extend(input_ids[start:end]) # text part
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
mm_processor_kwargs = dict(
|
||||
**mm_processor_kwargs,
|
||||
sampling_rate=feature_extractor.sampling_rate,
|
||||
)
|
||||
else:
|
||||
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
|
||||
pass
|
||||
|
||||
new_input_ids.extend([audio_token_index] *
|
||||
audio_output_lengths[audio_idx])
|
||||
start = end + 1
|
||||
new_input_ids.extend(input_ids[start:])
|
||||
return super()._call_hf_processor(
|
||||
hf_processor,
|
||||
prompt=prompt,
|
||||
processor_data=processor_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=new_input_ids,
|
||||
prompt=inputs.get("prompt"),
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_inputs: BatchFeature,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
|
||||
placeholder = hf_config.audio_token_index
|
||||
|
||||
feature_attention_mask = hf_inputs.get("feature_attention_mask")
|
||||
if feature_attention_mask is None:
|
||||
audio_output_lengths = []
|
||||
else:
|
||||
_, audio_output_lengths = _get_feat_extract_output_lengths(
|
||||
feature_attention_mask.sum(-1))
|
||||
|
||||
def input_mapper_for_qwen2_audio(
|
||||
ctx: InputContext,
|
||||
multi_modal_data: Union[np.ndarray, List[np.ndarray]],
|
||||
) -> MultiModalKwargs:
|
||||
"""Input mapper for Qwen2-Audio."""
|
||||
if not isinstance(multi_modal_data, list):
|
||||
multi_modal_data = [multi_modal_data]
|
||||
def get_replacement_qwen2_audio(item_idx: int):
|
||||
return [placeholder] * audio_output_lengths[item_idx]
|
||||
|
||||
if len(multi_modal_data) == 0:
|
||||
return MultiModalKwargs()
|
||||
|
||||
processor = cached_get_processor(ctx.model_config.model)
|
||||
audio_feature_extractor = processor.feature_extractor
|
||||
if audio_feature_extractor is None:
|
||||
raise RuntimeError(
|
||||
"No HuggingFace audio_feature_extractor is available "
|
||||
"to process the audio object")
|
||||
|
||||
try:
|
||||
resampled_audios = [
|
||||
librosa.resample(
|
||||
audio,
|
||||
orig_sr=sampling_rate,
|
||||
target_sr=processor.feature_extractor.sampling_rate)
|
||||
for audio, sampling_rate in multi_modal_data
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=[placeholder],
|
||||
replacement=get_replacement_qwen2_audio,
|
||||
)
|
||||
]
|
||||
batch_data = audio_feature_extractor(resampled_audios,
|
||||
sampling_rate=16000,
|
||||
return_attention_mask=True,
|
||||
padding="max_length",
|
||||
return_tensors="pt").data
|
||||
batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
|
||||
except Exception:
|
||||
logger.error("Failed to process audio (%s)", multi_modal_data)
|
||||
raise
|
||||
|
||||
return MultiModalKwargs(batch_data)
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
audio_len = get_max_qwen2_audio_audio_tokens(self.ctx)
|
||||
|
||||
audio_count = mm_counts["audio"]
|
||||
audio = np.zeros(audio_len)
|
||||
data = {"audio": [audio] * audio_count}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="<|AUDIO|>" * audio_count,
|
||||
mm_data=data,
|
||||
mm_processor_kwargs={},
|
||||
)
|
||||
|
||||
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_audio)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_audio)
|
||||
@MULTIMODAL_REGISTRY.register_input_mapper("audio",
|
||||
input_mapper_for_qwen2_audio)
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"audio", get_max_qwen2_audio_audio_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
|
||||
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
@@ -289,9 +214,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return get_sampler()
|
||||
|
||||
def _validate_and_reshape_mm_tensor(self,
|
||||
mm_input: Union[torch.Tensor,
|
||||
List[torch.Tensor]],
|
||||
def _validate_and_reshape_mm_tensor(self, mm_input: object,
|
||||
name: str) -> torch.Tensor:
|
||||
if not isinstance(mm_input, (torch.Tensor, list)):
|
||||
raise ValueError(f"Incorrect type of {name}. "
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import math
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
@@ -11,7 +11,7 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from transformers import BatchFeature
|
||||
from transformers import BatchFeature, ProcessorMixin
|
||||
from transformers.models.whisper import WhisperFeatureExtractor
|
||||
from transformers.models.whisper.modeling_whisper import WhisperEncoder
|
||||
|
||||
@@ -25,11 +25,11 @@ from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataDict,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
@@ -61,8 +61,8 @@ def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
|
||||
|
||||
|
||||
def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
|
||||
return cached_feature_extractor(
|
||||
ctx.get_hf_config(UltravoxConfig).audio_model_id)
|
||||
hf_config = ctx.get_hf_config(UltravoxConfig)
|
||||
return cached_feature_extractor(hf_config.audio_model_id)
|
||||
|
||||
|
||||
def get_ultravox_max_audio_tokens(ctx: InputContext):
|
||||
@@ -73,72 +73,71 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
|
||||
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
|
||||
return self._get_hf_processor().audio_processor.feature_extractor
|
||||
|
||||
def _resample_audio(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
sr: int,
|
||||
) -> Dict[str, Union[np.ndarray, int]]:
|
||||
# resample audio to the model's sampling rate
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
if sr != feature_extractor.sampling_rate:
|
||||
try:
|
||||
import librosa
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install vllm[audio] for audio support.") from exc
|
||||
audio = librosa.resample(audio,
|
||||
orig_sr=sr,
|
||||
target_sr=feature_extractor.sampling_rate)
|
||||
sr = feature_extractor.sampling_rate
|
||||
return {"audio": audio, "sampling_rate": sr}
|
||||
|
||||
def _apply_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
if not mm_data or not mm_data.get("audio", None):
|
||||
return super()._apply_hf_processor(prompt, mm_data,
|
||||
mm_processor_kwargs)
|
||||
|
||||
audio_data = mm_data["audio"]
|
||||
if not isinstance(audio_data, list):
|
||||
audio_data = [audio_data]
|
||||
|
||||
# Ultravox processor doesn't support multiple inputs,
|
||||
# therefore we need to input text and audio one by one
|
||||
tokenizer = self._get_tokenizer()
|
||||
audio_features, audio_token_len = [], []
|
||||
processed_inputs = {}
|
||||
for audio, sr in audio_data:
|
||||
data = self._resample_audio(audio, sr)
|
||||
processed_inputs = super()._apply_hf_processor(
|
||||
prompt, data, mm_processor_kwargs)
|
||||
prompt = tokenizer.decode(processed_inputs["input_ids"][0],
|
||||
skip_special_tokens=False)
|
||||
audio_features.append(
|
||||
processed_inputs.pop("audio_values").squeeze(0))
|
||||
audio_token_len.append(
|
||||
processed_inputs.pop("audio_token_len").item())
|
||||
|
||||
return dict(
|
||||
**processed_inputs,
|
||||
audio_features=audio_features,
|
||||
audio_token_len=audio_token_len,
|
||||
)
|
||||
hf_processor = self._get_hf_processor()
|
||||
return hf_processor.audio_processor.feature_extractor # type: ignore
|
||||
|
||||
def _get_processor_data(
|
||||
self,
|
||||
mm_data: MultiModalDataDict,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
# Ultravox uses "audio" instead of "audios" as calling keyword
|
||||
processor_data, passthrough_data = super()._get_processor_data(mm_data)
|
||||
if "audios" in processor_data:
|
||||
processor_data["audio"] = processor_data.pop("audios")
|
||||
return processor_data, passthrough_data
|
||||
mm_items: MultiModalDataItems,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
# resample audio to the model's sampling rate
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
mm_items.resample_audios(feature_extractor.sampling_rate)
|
||||
|
||||
return super()._get_processor_data(mm_items)
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
hf_processor: ProcessorMixin,
|
||||
prompt: str,
|
||||
processor_data: Mapping[str, object],
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processor_data = dict(processor_data)
|
||||
audios = processor_data.pop("audios", [])
|
||||
|
||||
if not audios:
|
||||
return super()._call_hf_processor(
|
||||
hf_processor,
|
||||
prompt=prompt,
|
||||
processor_data=processor_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
mm_processor_kwargs = dict(
|
||||
**mm_processor_kwargs,
|
||||
sampling_rate=feature_extractor.sampling_rate,
|
||||
)
|
||||
|
||||
# Already resampled by _get_processor_data
|
||||
assert is_list_of(audios, np.ndarray)
|
||||
|
||||
# Ultravox processor doesn't support multiple inputs,
|
||||
# therefore we need to input text and audio one by one
|
||||
audio_features, audio_token_len = [], []
|
||||
shared_outputs = {}
|
||||
for audio in audios:
|
||||
# NOTE: Ultravox processor accepts "audio" instead of "audios"
|
||||
item_processor_data = dict(**processor_data, audio=audio)
|
||||
|
||||
item_outputs = super()._call_hf_processor(
|
||||
hf_processor,
|
||||
prompt=prompt,
|
||||
processor_data=item_processor_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
audio_features.append(item_outputs.pop("audio_values")[0])
|
||||
audio_token_len.append(item_outputs.pop("audio_token_len").item())
|
||||
shared_outputs = item_outputs
|
||||
|
||||
combined_outputs = dict(
|
||||
**shared_outputs,
|
||||
audio_features=audio_features,
|
||||
audio_token_len=audio_token_len,
|
||||
)
|
||||
return BatchFeature(combined_outputs)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
@@ -147,7 +146,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self._get_hf_processor()
|
||||
placeholder = hf_processor.audio_token_replacement
|
||||
placeholder = hf_processor.audio_token_replacement # type: ignore
|
||||
|
||||
def get_replacement_ultravox(item_idx: int):
|
||||
audio_token_len = hf_inputs["audio_token_len"][item_idx]
|
||||
@@ -171,7 +170,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
audio_count = mm_counts["audio"]
|
||||
audio = np.zeros(audio_len)
|
||||
data = {"audio": [(audio, sampling_rate)] * audio_count}
|
||||
data = {"audio": [audio] * audio_count}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="<|audio|>" * audio_count,
|
||||
|
||||
Reference in New Issue
Block a user