[Model] Define merge_by_field_config MM interface (U-Z) (#26261)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFieldConfig,
@@ -51,6 +51,7 @@ from vllm.multimodal.processing import (
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.transformers_utils.processor import cached_get_processor
+from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
@@ -135,7 +136,10 @@ class WhisperAudioInputs(TensorSchema):
         - t: Time frames (M)
     """
 
-    input_features: Annotated[Optional[NestedTensors], TensorShape("b", "nmb", "t")]
+    input_features: Annotated[
+        Optional[list[torch.Tensor]],
+        TensorShape("b", "nmb", "t"),
+    ]
 
 
 class WhisperEncoderAttention(MultiHeadAttention):
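The hunk above narrows the schema from NestedTensors to list[torch.Tensor] while keeping the symbolic TensorShape("b", "nmb", "t") annotation, where the list length plays the "b" role and each element carries the remaining ("nmb", "t") dims. A minimal sketch of that kind of check (not vLLM's actual TensorSchema implementation; check_item_dims is an illustrative name):

import torch

def check_item_dims(tensors: list[torch.Tensor], dims: tuple[str, ...]) -> None:
    # Each named dim must bind to one consistent size across all items.
    bound: dict[str, int] = {}
    for t in tensors:
        if t.ndim != len(dims):
            raise ValueError(f"expected dims {dims}, got shape {tuple(t.shape)}")
        for name, size in zip(dims, t.shape):
            if bound.setdefault(name, size) != size:
                raise ValueError(f"dim {name!r} mismatch: {bound[name]} vs {size}")

feats = [torch.zeros(80, 3000), torch.zeros(80, 3000)]  # b = 2 items of (nmb, t)
check_item_dims(feats, ("nmb", "t"))  # passes; an (80, 2999) entry would raise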
@@ -781,6 +785,7 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
 class WhisperForConditionalGeneration(
     nn.Module, SupportsTranscription, SupportsMultiModal
 ):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",
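Per the PR title, merge_by_field_config = True opts the model into the newer multimodal interface where the engine merges incoming kwargs according to each field's MultiModalFieldConfig before they reach the model, so the model no longer flattens per-request inputs itself. A hypothetical stand-in for that engine-side step (merge_batched is not a vLLM API, just an illustration of merging a "batched" field):

import torch

def merge_batched(per_request: list[list[torch.Tensor]]) -> list[torch.Tensor]:
    # For a field declared as batched, merging flattens the per-request
    # lists into one batch-level list of per-item tensors.
    return [item for req in per_request for item in req]

reqs = [
    [torch.zeros(80, 3000)],                         # request 1: one audio clip
    [torch.zeros(80, 3000), torch.zeros(80, 3000)],  # request 2: two clips
]
batch = merge_batched(reqs)  # 3 per-item tensors, matching list[torch.Tensor] above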
@@ -936,12 +941,7 @@ class WhisperForConditionalGeneration(
         input_features = kwargs.pop("input_features", None)
 
         if input_features is not None:
-            if not isinstance(input_features, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of audio features. "
-                    f"Got type: {type(input_features)}"
-                )
-            input_features = torch.cat([feat.to(self.dtype) for feat in input_features])
+            input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)
 
         return WhisperAudioInputs(input_features=input_features)
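With the engine now delivering already-merged fields, the manual isinstance check and torch.cat are replaced by a single json_map_leaves call that casts each leaf tensor to the model dtype in place of structure-specific handling. A minimal re-implementation sketch of what such a helper does, written from the call site above rather than from the vllm.utils.jsontree source (map_leaves is an illustrative name):

from typing import Any, Callable

import torch

def map_leaves(fn: Callable[[Any], Any], tree: Any) -> Any:
    # Recurse through dict/list/tuple containers; apply fn at each leaf,
    # preserving the container structure.
    if isinstance(tree, dict):
        return {k: map_leaves(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_leaves(fn, v) for v in tree)
    return fn(tree)

feats = [torch.zeros(80, 3000, dtype=torch.float32)]
feats = map_leaves(lambda x: x.to(torch.float16), feats)  # cast every leaf tensor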