Migrate Qwen2 inputs to TensorSchema (#23475)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Benji Beck
2025-09-06 20:07:31 -07:00
committed by GitHub
parent 558f0907dc
commit 37a6fa95fd
4 changed files with 268 additions and 175 deletions

View File

@@ -23,7 +23,7 @@
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union
import torch
import torch.nn as nn
@@ -47,6 +47,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
@@ -54,21 +55,38 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model,
# # === Audio Inputs === #
# Raw (pre-encoder) audio features, as produced by the HF feature extractor.
# NOTE(review): the fixed trailing dim 3000 presumably comes from the
# Whisper-style 30s mel window — confirm against the processor config.
class Qwen2AudioFeatureInputs(TypedDict):
    # Discriminator used to tell feature inputs apart from embedding inputs.
    type: Literal["audio_features"]
    input_features: torch.Tensor
    """Shape: `(num_audios, num_mel_bins, 3000)`"""
    # Mask marking which of the 3000 frames per audio are real (not padding).
    feature_attention_mask: torch.Tensor
    """Shape: `(num_audios, 3000)`"""
class Qwen2AudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
audio_embeds: list[torch.Tensor]
"""Shape: `(num_audio_features, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
class Qwen2AudioFeatureInputs(TensorSchema):
    """
    Raw (pre-encoder) audio features for Qwen2-Audio.

    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins
    """
    # Give the discriminator a default, consistent with
    # Qwen2AudioEmbeddingInputs (whose `type` defaults to "audio_embeds"),
    # so constructing the schema does not require spelling out the literal.
    type: Literal["audio_features"] = "audio_features"

    # Either a single batched tensor or a per-audio list of tensors; the
    # trailing 3000 is the fixed mel-frame count enforced by TensorShape.
    input_features: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("na", "nmb", 3000),
    ]

    # Per-audio mask over the 3000 mel frames (1 = real frame, 0 = padding
    # — presumably; verify against the HF feature extractor's output).
    feature_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("na", 3000),
    ]
class Qwen2AudioEmbeddingInputs(TensorSchema):
    """
    Pre-computed audio embeddings, bypassing the audio encoder.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
            backbone)
    """
    # Discriminator used to tell embedding inputs apart from feature inputs.
    type: Literal["audio_embeds"] = "audio_embeds"
    # One tensor per batch item; each is (naf, hs) per the declared shape.
    audio_embeds: Annotated[
        list[torch.Tensor],
        TensorShape("bn", "naf", "hs"),
    ]
# Tagged union over the two input forms; discriminated by the `type` field.
Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]