Migrate Qwen2 inputs to TensorSchema (#23475)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Benji Beck
2025-09-06 20:07:31 -07:00
committed by GitHub
parent 558f0907dc
commit 37a6fa95fd
4 changed files with 268 additions and 175 deletions

View File

@@ -25,7 +25,7 @@
from collections.abc import Iterable, Mapping, Sequence
from copy import copy
from functools import partial
from typing import Any, Callable, Optional, Union
from typing import Annotated, Any, Callable, Literal, Optional, Union
import torch
import torch.nn as nn
@@ -41,15 +41,13 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
from vllm.model_executor.models.qwen2_audio import (
Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
_get_feat_extract_output_lengths)
Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -66,9 +64,9 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
@@ -81,6 +79,26 @@ except (ImportError, ModuleNotFoundError):
logger = init_logger(__name__)
class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
    """
    Audio feature inputs, as built by `_parse_and_validate_audio_input`.

    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins
        - msl: Maximum sequence length
        - tsl: Total sequence length
    """

    # Discriminator tag; the only accepted literal is "audio_features".
    type: Literal["audio_features"]

    # A single tensor or a list of tensors, annotated with shape (nmb, tsl).
    input_features: Annotated[Union[torch.Tensor, list[torch.Tensor]],
                              TensorShape("nmb", "tsl")]

    # A tensor annotated with shape (na, msl).
    feature_attention_mask: Annotated[torch.Tensor,
                                      TensorShape("na", "msl")]

    # NOTE(review): `_parse_and_validate_audio_input` also passes
    # `audio_feature_lengths=` when constructing this schema, but no such
    # field is declared here -- confirm whether it should be annotated.
def create_qwen2_5_omni_thinker_field_factory(
spatial_merge_size: int
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
@@ -536,7 +554,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
return torch.concat(mm_input, dim=dim)
def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
input_audio_features = kwargs.pop('input_audio_features', None)
audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
feature_attention_mask = kwargs.pop('feature_attention_mask', None)
@@ -550,7 +568,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
if not isinstance(input_audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio input features. "
f"Got type: {type(input_audio_features)}")
return Qwen2AudioFeatureInputs(
return Qwen2_5OmniAudioFeatureInputs(
type="audio_features",
input_features=input_audio_features,
audio_feature_lengths=audio_feature_lengths,
feature_attention_mask=feature_attention_mask)
@@ -633,7 +652,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _process_audio_input(
self,
audio_input: Qwen2AudioFeatureInputs,
audio_input: Qwen2_5OmniAudioFeatureInputs,
audio_hashes: list[str] = None,
cached_audio_features: torch.Tensor = None,
) -> torch.Tensor:
@@ -660,8 +679,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
feature_lens=audio_feature_lengths,
aftercnn_lens=audio_feat_lengths,
)
audio_features = audio_outputs.last_hidden_state
return audio_features.split(audio_output_lengths.tolist())
return audio_outputs.last_hidden_state.split(
audio_output_lengths.tolist())
def _process_image_input(
self,
@@ -707,7 +726,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
)
class Qwen2_5OmniThinkerForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
nn.Module, SupportsMultiModal, SupportsPP,
Qwen2_5OmniConditionalGenerationMixin):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
@@ -800,15 +819,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_mm_mapping(self) -> MultiModelKeys:
"""Get module prefix for multimodal models to filter LoRA modules."""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector=[], # No explicit connector in this model
tower_model=["visual",
"audio_tower"], # Exclude vision and audio towers
)
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings: