Migrate Qwen2 inputs to TensorSchema (#23475)
Signed-off-by: Benji Beck <benjibeck@meta.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -25,7 +25,7 @@
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -41,15 +41,13 @@ from transformers.models.whisper import WhisperFeatureExtractor
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen2_5_vl import (
|
||||
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
|
||||
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
|
||||
Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
|
||||
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
|
||||
from vllm.model_executor.models.qwen2_audio import (
|
||||
Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
|
||||
_get_feat_extract_output_lengths)
|
||||
Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
|
||||
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@@ -66,9 +64,9 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
@@ -81,6 +79,26 @@ except (ImportError, ModuleNotFoundError):
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- na: Number of audios
|
||||
- nmb: Number of mel bins
|
||||
- msl: Maximum sequence length
|
||||
- tsl: Total sequence length
|
||||
"""
|
||||
type: Literal["audio_features"]
|
||||
input_features: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("nmb", "tsl"),
|
||||
]
|
||||
|
||||
feature_attention_mask: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("na", "msl"),
|
||||
]
|
||||
|
||||
|
||||
def create_qwen2_5_omni_thinker_field_factory(
|
||||
spatial_merge_size: int
|
||||
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
|
||||
@@ -536,7 +554,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
return torch.concat(mm_input, dim=dim)
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
|
||||
self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
|
||||
input_audio_features = kwargs.pop('input_audio_features', None)
|
||||
audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
|
||||
feature_attention_mask = kwargs.pop('feature_attention_mask', None)
|
||||
@@ -550,7 +568,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
if not isinstance(input_audio_features, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio input features. "
|
||||
f"Got type: {type(input_audio_features)}")
|
||||
return Qwen2AudioFeatureInputs(
|
||||
return Qwen2_5OmniAudioFeatureInputs(
|
||||
type="audio_features",
|
||||
input_features=input_audio_features,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
feature_attention_mask=feature_attention_mask)
|
||||
@@ -633,7 +652,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
|
||||
def _process_audio_input(
|
||||
self,
|
||||
audio_input: Qwen2AudioFeatureInputs,
|
||||
audio_input: Qwen2_5OmniAudioFeatureInputs,
|
||||
audio_hashes: list[str] = None,
|
||||
cached_audio_features: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -660,8 +679,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
feature_lens=audio_feature_lengths,
|
||||
aftercnn_lens=audio_feat_lengths,
|
||||
)
|
||||
audio_features = audio_outputs.last_hidden_state
|
||||
return audio_features.split(audio_output_lengths.tolist())
|
||||
return audio_outputs.last_hidden_state.split(
|
||||
audio_output_lengths.tolist())
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
@@ -707,7 +726,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
|
||||
)
|
||||
class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
nn.Module, SupportsMultiModal, SupportsPP,
|
||||
Qwen2_5OmniConditionalGenerationMixin):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
@@ -800,15 +819,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""Get module prefix for multimodal models to filter LoRA modules."""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector=[], # No explicit connector in this model
|
||||
tower_model=["visual",
|
||||
"audio_tower"], # Exclude vision and audio towers
|
||||
)
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user