Migrate Qwen2 inputs to TensorSchema (#23475)
Signed-off-by: Benji Beck <benjibeck@meta.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -23,7 +23,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -47,6 +47,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
@@ -54,21 +55,38 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
|
||||
|
||||
# === Audio Inputs === #
|
||||
class Qwen2AudioFeatureInputs(TypedDict):
    """Raw (mel-spectrogram) audio features for Qwen2-Audio."""

    type: Literal["audio_features"]

    # Mel features per clip; shape: (num_audios, num_mel_bins, 3000).
    input_features: torch.Tensor

    # Frame-validity mask matching the last axis of ``input_features``;
    # shape: (num_audios, 3000).
    feature_attention_mask: torch.Tensor
|
||||
|
||||
|
||||
class Qwen2AudioEmbeddingInputs(TypedDict):
|
||||
type: Literal["audio_embeds"]
|
||||
audio_embeds: list[torch.Tensor]
|
||||
"""Shape: `(num_audio_features, hidden_size)`
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
class Qwen2AudioFeatureInputs(TensorSchema):
    """
    Raw (mel-spectrogram) audio features for Qwen2-Audio.

    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins
        - 3000: Fixed frame count per clip (presumably the Whisper-style
          30 s feature window — TODO confirm against the feature extractor)
    """
    # Defaulted so the discriminator need not be passed explicitly;
    # mirrors Qwen2AudioEmbeddingInputs, whose ``type`` has a default.
    type: Literal["audio_features"] = "audio_features"

    input_features: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("na", "nmb", 3000),
    ]

    # Frame-validity mask over the 3000-frame axis.
    feature_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("na", 3000),
    ]
|
||||
|
||||
|
||||
class Qwen2AudioEmbeddingInputs(TensorSchema):
    """
    Audio inputs supplied as precomputed embeddings.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """
    type: Literal["audio_embeds"] = "audio_embeds"

    # One tensor per audio clip; validated against (bn, naf, hs).
    audio_embeds: Annotated[list[torch.Tensor],
                            TensorShape("bn", "naf", "hs")]
|
||||
|
||||
|
||||
# Union of the two accepted audio-input forms, discriminated by ``type``.
Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
|
||||
|
||||
Reference in New Issue
Block a user