Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -11,23 +11,39 @@ from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargsItems,
|
||||
MultiModalUUIDDict)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptIndexTargets,
|
||||
PromptInsertion, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalInputs,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalUUIDDict,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
ImageEmbeddingItems,
|
||||
ImageProcessorItems,
|
||||
MultiModalDataItems,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptIndexTargets,
|
||||
PromptInsertion,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -41,6 +57,7 @@ class PaliGemmaImagePixelInputs(TensorSchema):
|
||||
- h: Height
|
||||
- w: Width
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values"] = "pixel_values"
|
||||
data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
|
||||
|
||||
@@ -52,16 +69,15 @@ class PaliGemmaImageEmbeddingInputs(TensorSchema):
|
||||
- ifs: Image feature size
|
||||
- hs: Hidden size (must match language model backbone)
|
||||
"""
|
||||
|
||||
type: Literal["image_embeds"] = "image_embeds"
|
||||
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
|
||||
|
||||
|
||||
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
|
||||
PaliGemmaImageEmbeddingInputs]
|
||||
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, PaliGemmaImageEmbeddingInputs]
|
||||
|
||||
|
||||
class PaliGemmaMultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(self, vision_hidden_size: int, projection_dim: int):
|
||||
super().__init__()
|
||||
|
||||
@@ -73,7 +89,6 @@ class PaliGemmaMultiModalProjector(nn.Module):
|
||||
|
||||
|
||||
class PaliGemmaProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(PaliGemmaConfig)
|
||||
|
||||
@@ -97,9 +112,7 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaDummyInputsBuilder(
|
||||
BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
|
||||
|
||||
class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
return ""
|
||||
|
||||
@@ -118,17 +131,16 @@ class PaliGemmaDummyInputsBuilder(
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=max_image_size,
|
||||
height=max_image_size,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
"image": self._get_dummy_images(
|
||||
width=max_image_size,
|
||||
height=max_image_size,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class PaliGemmaMultiModalProcessor(
|
||||
BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
|
||||
|
||||
class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -171,7 +183,8 @@ class PaliGemmaMultiModalProcessor(
|
||||
|
||||
def get_insertion(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems)
|
||||
)
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
num_image_tokens = images.get_feature_size(item_idx)
|
||||
@@ -196,7 +209,8 @@ class PaliGemmaMultiModalProcessor(
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target=PromptIndexTargets.prefix(
|
||||
[bos_token_id] if tokenizer.add_bos_token else []),
|
||||
[bos_token_id] if tokenizer.add_bos_token else []
|
||||
),
|
||||
insertion=get_insertion,
|
||||
)
|
||||
]
|
||||
@@ -209,11 +223,13 @@ class PaliGemmaMultiModalProcessor(
|
||||
tokenization_kwargs: Optional[Mapping[str, object]] = None,
|
||||
mm_uuids: Optional[MultiModalUUIDDict] = None,
|
||||
) -> MultiModalInputs:
|
||||
mm_inputs = super().apply(prompt,
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs,
|
||||
mm_uuids=mm_uuids)
|
||||
mm_inputs = super().apply(
|
||||
prompt,
|
||||
mm_data,
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
@@ -231,9 +247,9 @@ class PaliGemmaMultiModalProcessor(
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
PaliGemmaMultiModalProcessor,
|
||||
info=PaliGemmaProcessingInfo,
|
||||
dummy_inputs=PaliGemmaDummyInputsBuilder)
|
||||
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
dummy_inputs=PaliGemmaDummyInputsBuilder,
|
||||
)
|
||||
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -253,7 +269,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"model.vision_tower.": "vision_tower.",
|
||||
"model.multi_modal_projector.": "multi_modal_projector.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
@@ -270,13 +287,15 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.vision_tower = SiglipVisionModel(config.vision_config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "vision_tower"))
|
||||
self.vision_tower = SiglipVisionModel(
|
||||
config.vision_config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
self.multi_modal_projector = PaliGemmaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
projection_dim=config.vision_config.projection_dim)
|
||||
projection_dim=config.vision_config.projection_dim,
|
||||
)
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
@@ -293,10 +312,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.language_model.logits_processor.scale *= logit_scale
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
|
||||
self, **kwargs: object
|
||||
) -> Optional[PaliGemmaImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
@@ -307,12 +328,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
|
||||
h = w = self.config.vision_config.image_size
|
||||
return PaliGemmaImagePixelInputs(type="pixel_values",
|
||||
data=pixel_values,
|
||||
resolve_bindings={
|
||||
"h": h,
|
||||
"w": w
|
||||
})
|
||||
return PaliGemmaImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=pixel_values,
|
||||
resolve_bindings={"h": h, "w": w},
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
image_embeds = flatten_bn(image_embeds, concat=True)
|
||||
@@ -329,7 +349,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
vision_tower: SiglipVisionModel,
|
||||
pixel_values: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
target_dtype = vision_tower.get_input_embeddings().weight.dtype
|
||||
image_features = vision_tower(pixel_values.to(dtype=target_dtype))
|
||||
|
||||
@@ -339,7 +358,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
image_input: PaliGemmaImageInputs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
@@ -355,8 +373,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
@@ -365,19 +382,20 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
|
||||
return vision_embeddings
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object) -> IntermediateTensors:
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -387,7 +405,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
Reference in New Issue
Block a user