[Misc] Clean up type annotation for SupportsMultiModal (#14794)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-14 15:59:56 +08:00
committed by GitHub
parent 09269b3127
commit 601bd3268e
27 changed files with 121 additions and 141 deletions

View File

@@ -37,8 +37,7 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -47,7 +46,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@@ -357,8 +356,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
audio_output_lengths.flatten().tolist())
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return None
@@ -368,7 +366,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: