Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -74,7 +74,9 @@ class PaliGemmaImageEmbeddingInputs(TensorSchema):
|
||||
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
|
||||
|
||||
|
||||
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, PaliGemmaImageEmbeddingInputs]
|
||||
PaliGemmaImageInputs: TypeAlias = (
|
||||
PaliGemmaImagePixelInputs | PaliGemmaImageEmbeddingInputs
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaMultiModalProjector(nn.Module):
|
||||
@@ -95,7 +97,7 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
|
||||
def get_vision_encoder_info(self):
|
||||
return get_vision_encoder_info(self.get_hf_config())
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": 1}
|
||||
|
||||
def get_num_image_tokens(
|
||||
@@ -120,7 +122,7 @@ class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalDataDict:
|
||||
hf_config = self.info.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
@@ -217,11 +219,11 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: Union[str, list[int]],
|
||||
prompt: str | list[int],
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Optional[Mapping[str, object]] = None,
|
||||
mm_uuids: Optional[MultiModalUUIDDict] = None,
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> MultiModalInputs:
|
||||
mm_inputs = super().apply(
|
||||
prompt,
|
||||
@@ -273,7 +275,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
@@ -317,7 +319,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> Optional[PaliGemmaImageInputs]:
|
||||
) -> PaliGemmaImageInputs | None:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
@@ -386,8 +388,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
@@ -402,7 +404,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> Optional[torch.Tensor]:
|
||||
) -> torch.Tensor | None:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
|
||||
Reference in New Issue
Block a user