Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -16,28 +16,34 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Fuyu model."""
|
||||
"""PyTorch Fuyu model."""
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
|
||||
FuyuProcessor)
|
||||
from transformers import BatchFeature, FuyuConfig, FuyuImageProcessor, FuyuProcessor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
@@ -72,7 +78,6 @@ class FuyuImagePatchInputs(TensorSchema):
|
||||
|
||||
|
||||
class FuyuProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(FuyuConfig)
|
||||
|
||||
@@ -124,12 +129,12 @@ class FuyuProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
image_processor = self.get_image_processor()
|
||||
return ImageSize(width=image_processor.size["width"],
|
||||
height=image_processor.size["height"])
|
||||
return ImageSize(
|
||||
width=image_processor.size["width"], height=image_processor.size["height"]
|
||||
)
|
||||
|
||||
|
||||
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
return ""
|
||||
|
||||
@@ -139,23 +144,22 @@ class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
target_width, target_height = self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
"image": self._get_dummy_images(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -179,7 +183,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
image_patches = processed_outputs["image_patches"]
|
||||
processed_outputs["image_patches"] = flatten_bn(image_patches)
|
||||
processed_outputs["patches_per_image"] = torch.tensor(
|
||||
[len(p) for p in image_patches])
|
||||
[len(p) for p in image_patches]
|
||||
)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
@@ -206,7 +211,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
|
||||
return dict(
|
||||
image_patches=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", patches_per_image),
|
||||
"image", patches_per_image
|
||||
),
|
||||
patches_per_image=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
@@ -232,8 +238,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||
[_NEWLINE_TOKEN_ID]) * nrows
|
||||
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
image_tokens + [bos_token_id],
|
||||
@@ -249,9 +254,11 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
|
||||
info=FuyuProcessingInfo,
|
||||
dummy_inputs=FuyuDummyInputsBuilder)
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
FuyuMultiModalProcessor,
|
||||
info=FuyuProcessingInfo,
|
||||
dummy_inputs=FuyuDummyInputsBuilder,
|
||||
)
|
||||
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
merge_by_field_config = True
|
||||
|
||||
@@ -260,7 +267,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
"model.vision_embed_tokens.": "vision_embed_tokens.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
@@ -292,10 +300,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
|
||||
self, **kwargs: object
|
||||
) -> Optional[FuyuImagePatchInputs]:
|
||||
image_patches = kwargs.pop("image_patches", None)
|
||||
patches_per_image = kwargs.pop("patches_per_image", None)
|
||||
|
||||
@@ -310,21 +320,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
)
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
|
||||
self, image_input: FuyuImagePatchInputs
|
||||
) -> MultiModalEmbeddings:
|
||||
image_patches_flat = image_input["image_patches_flat"]
|
||||
patches_per_image = image_input["patches_per_image"]
|
||||
|
||||
assert self.vision_embed_tokens is not None
|
||||
vision_embeddings_flat, _ = self.vision_embed_tokens(
|
||||
image_patches_flat)
|
||||
vision_embeddings_flat, _ = self.vision_embed_tokens(image_patches_flat)
|
||||
|
||||
return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
@@ -355,10 +364,10 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
hidden_states: torch.Tensor,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.language_model.logits_processor(
|
||||
self.language_model.lm_head, hidden_states)
|
||||
self.language_model.lm_head, hidden_states
|
||||
)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
Reference in New Issue
Block a user