[LMM] Implement merged multimodal processor for whisper (#13278)

This commit is contained in:
Isotr0py
2025-02-23 17:46:03 +08:00
committed by GitHub
parent d5ca2110f1
commit ba5106e519
4 changed files with 144 additions and 77 deletions

View File

@@ -4,15 +4,15 @@ import math
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
import numpy as np
import torch
from torch import nn
from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor)
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.audio import resample_audio
from vllm.sequence import SequenceData
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .interfaces import SupportsMultiModal, SupportsTranscription
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
@@ -571,72 +574,126 @@ class WhisperModel(nn.Module):
return loaded_params
def get_max_whisper_audio_tokens(ctx: InputContext) -> int:
    """Return the encoder's audio-token budget from the HF config."""
    hf_config = ctx.model_config.hf_config
    return hf_config.max_source_positions
class WhisperProcessingInfo(BaseProcessingInfo):
    """Exposes Whisper's HF config, processor and modality limits."""

    def get_hf_config(self) -> WhisperConfig:
        # Typed accessor so callers receive a WhisperConfig specifically.
        return self.ctx.get_hf_config(WhisperConfig)

    def get_hf_processor(self,
                         sampling_rate: Optional[int] = None
                         ) -> WhisperProcessor:
        # ``sampling_rate`` is accepted for interface parity but is not
        # forwarded; the cached processor already carries its own rate.
        return self.ctx.get_hf_processor(WhisperProcessor)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Whisper handles exactly one audio clip per request.
        return {"audio": 1}

    def get_feature_extractor(self) -> WhisperFeatureExtractor:
        extractor = self.get_hf_processor().feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_max_audio_tokens(self) -> int:
        # One encoder token per source position.
        return self.get_hf_config().max_source_positions

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"audio": self.get_max_audio_tokens()}
def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int,
                                   mm_counts: Mapping[str, int]):
    """Build placeholder encoder data (one silent clip) for profiling."""
    assert mm_counts["audio"] == 1

    token_budget = get_max_whisper_audio_tokens(ctx)
    feature_extractor = cached_processor_from_config(
        ctx.model_config).feature_extractor
    sr = feature_extractor.sampling_rate
    # A full chunk of silence at the extractor's native sampling rate.
    silence = np.zeros(feature_extractor.chunk_length * sr)

    return DummyData(
        SequenceData.from_prompt_token_counts((0, token_budget)),
        {"audio": [(silence, sr)]},
    )
class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
    """Creates synthetic audio inputs used to profile Whisper's processor."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        extractor = self.info.get_feature_extractor()
        # One full chunk of audio at the extractor's native sampling rate.
        clip_len = extractor.chunk_length * extractor.sampling_rate
        n_audios = mm_counts.get("audio", 0)

        return ProcessorInputs(
            prompt_text="<|startoftranscript|>" * n_audios,
            mm_data={
                "audio":
                self._get_dummy_audios(length=clip_len, num_audios=n_audios)
            },
        )
def input_processor_for_whisper(ctx: InputContext, inputs):
    """Resample the request's audio and pre-fill encoder placeholder tokens."""
    encoder_inputs = inputs["encoder"]
    mm_data = encoder_inputs["multi_modal_data"]

    audio_entry = mm_data["audio"]
    if isinstance(audio_entry, list):
        # Only a single audio clip is supported per request.
        assert len(audio_entry) == 1
        audio_entry = audio_entry[0]

    # Resample to the feature extractor's native sampling rate.
    waveform, orig_sr = audio_entry
    processor = cached_processor_from_config(ctx.model_config)
    target_sr = processor.feature_extractor.sampling_rate
    resampled = resample_audio(waveform, orig_sr=orig_sr, target_sr=target_sr)
    mm_data["audio"] = (resampled, target_sr)

    # Reserve the full encoder sequence with placeholder token 0.
    encoder_inputs["prompt_token_ids"] = [0] * get_max_whisper_audio_tokens(ctx)
    return inputs
class WhisperMultiModalProcessor(
        EncDecMultiModalProcessor[WhisperProcessingInfo]):
    """Merged multimodal processor feeding Whisper's audio encoder."""

    def _get_data_parser(self) -> MultiModalDataParser:
        # The parser resamples incoming audio to the extractor's native rate.
        extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=extractor.sampling_rate)

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        # Strictly speaking, whisper encoder only accept audio features.
        # We create a dummy encoder prompt here which will be padded to
        # num_audio_tokens. So that we can create dummy data from this
        # for encoder profiling.
        return [0]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if mm_data:
            extractor = self.info.get_feature_extractor(**mm_kwargs)
            # Rekey "audios" -> "audio" and pin the sampling rate so the HF
            # processor does not guess it.
            mm_data = {"audio": mm_data.pop("audios")}
            mm_kwargs = {
                **mm_kwargs,
                "sampling_rate": extractor.sampling_rate,
            }

        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )
        # Downstream consumes the tokenized prompt as "input_ids", so move
        # it over when the HF processor emitted it under "labels".
        if "labels" in outputs:
            outputs["input_ids"] = outputs.pop("labels")
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {"input_features": MultiModalFieldConfig.batched("audio")}

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        # Expand the single [0] encoder placeholder to the full token budget.
        audio_token_count = self.info.get_max_audio_tokens()
        return [
            PromptReplacement(
                modality="audio",
                target=[0],
                replacement=[0] * audio_token_count,
            )
        ]
def input_mapper_for_whisper(
    ctx: InputContext,
    multi_modal_data: Union[np.ndarray, List[np.ndarray]],
) -> MultiModalKwargs:
    """Map raw (audio, sampling_rate) data to Whisper input features.

    Args:
        ctx: Model context providing the cached HF processor and model dtype.
        multi_modal_data: A single (audio, sampling_rate) pair or a list
            containing exactly one such pair.

    Returns:
        ``MultiModalKwargs`` holding ``input_features`` cast to the model
        dtype, or empty kwargs when no audio is supplied.
    """
    if not isinstance(multi_modal_data, list):
        multi_modal_data = [multi_modal_data]

    # Handle the "no audio" case BEFORE asserting the single-clip invariant.
    # Previously the assert came first, which made this branch unreachable
    # and turned empty input into an AssertionError.
    if len(multi_modal_data) == 0:
        return MultiModalKwargs()
    assert len(multi_modal_data) == 1

    processor = cached_processor_from_config(ctx.model_config)
    sampling_rate = processor.feature_extractor.sampling_rate

    # Drop the per-item sampling rates; audio is assumed already resampled
    # to the extractor's rate by the input processor.
    audios = [audio for audio, _ in multi_modal_data]

    kwargs = processor(audios,
                       sampling_rate=sampling_rate,
                       return_tensors="pt")
    # Single clip: squeeze the batch dim and cast to the model dtype.
    kwargs["input_features"] = kwargs["input_features"].squeeze(0).to(
        ctx.model_config.dtype)

    return MultiModalKwargs(kwargs)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper)
@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_whisper_audio_tokens)
@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor,
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal):
packed_modules_mapping = {
@@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
if not isinstance(input_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(input_features)}")
input_features = [feat.to(self.dtype) for feat in input_features]
input_features = torch.cat(
[feat.to(self.dtype) for feat in input_features])
return WhisperAudioInputs(input_features=input_features)