From 731a6940e39e84619bbc8db8a794563bb8cc61a5 Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Wed, 3 Sep 2025 11:04:00 -0700 Subject: [PATCH] Migrate whisper inputs to TensorSchema (#23505) Signed-off-by: Benji Beck --- vllm/model_executor/models/whisper.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 848b6e0f8..97e8cd6e7 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -4,7 +4,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from contextlib import nullcontext -from typing import Literal, Optional, TypedDict, Union, cast +from typing import Annotated, Literal, Optional, Union, cast import numpy as np import torch @@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription, SupportsV0Only) @@ -111,9 +112,16 @@ ISO639_1_SUPPORTED_LANGS = { } -class WhisperAudioInputs(TypedDict): - input_features: NestedTensors - """Shape: `(batch_size, 128, M)`""" +class WhisperAudioInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - nmb: Number of mel bins + - t: Time frames (M) + """ + + input_features: Annotated[Optional[NestedTensors], + TensorShape("b", "nmb", "t")] class WhisperPositionalEmbedding(nn.Embedding):