Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions


@@ -11,34 +11,44 @@ from torch import nn
from transformers import BatchFeature, PretrainedConfig
from transformers.models.cohere2_vision import Cohere2VisionConfig
from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import ( # noqa: E501
-    get_optimal_tiled_canvas)
+    get_optimal_tiled_canvas,
+)
from transformers.models.cohere2_vision.processing_cohere2_vision import (
-    Cohere2VisionProcessor)
+    Cohere2VisionProcessor,
+)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import MulAndSilu
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
-from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
-                                   MultiModalDataItems)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo,
-                                        MultiModalFieldConfig,
-                                        PromptReplacement, PromptUpdate,
-                                        PromptUpdateDetails)
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    MultiModalFieldConfig,
+    PromptReplacement,
+    PromptUpdate,
+    PromptUpdateDetails,
+)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper,
-                    init_vllm_registered_model, maybe_prefix)
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
class Cohere2VisionImagePixelInputs(TensorSchema):
@@ -67,7 +77,7 @@ class Cohere2VisionImagePixelInputs(TensorSchema):
class Cohere2VisionMultiModalProjector(nn.Module):
"""Multimodal projector that maps vision features to text embedding space.
Uses pixel shuffle downsampling followed by SwiGLU activation.
"""
@@ -76,8 +86,7 @@ class Cohere2VisionMultiModalProjector(nn.Module):
self.downsample_factor = config.downsample_factor
# Input dimension after pixel shuffle downsampling
-        input_dim = config.vision_config.hidden_size * (
-            config.downsample_factor**2)
+        input_dim = config.vision_config.hidden_size * (config.downsample_factor**2)
# MergedColumnParallelLinear expects the intermediate size to be a list
# of sizes, so that it will load the weights as two separate linear
# layers before applying any parallelism.
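As a rough sanity check of the input-dimension arithmetic above, a minimal sketch; the numbers are illustrative assumptions, not values read from the real Cohere2 config:

# hypothetical SigLIP hidden size and downsample factor, chosen only for illustration
hidden_size, downsample_factor = 1152, 2
input_dim = hidden_size * downsample_factor**2
print(input_dim)  # 4608 features per downsampled position enter the SwiGLU projector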
@@ -110,28 +119,26 @@ class Cohere2VisionMultiModalProjector(nn.Module):
def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor:
"""Apply pixel shuffle downsampling to reduce spatial dimensions.
Args:
image_features: Input tensor of shape [B, S, D] where S = H*W
Returns:
Downsampled tensor with increased channel dimension
"""
-        height = width = int(image_features.shape[1]**0.5)
+        height = width = int(image_features.shape[1] ** 0.5)
x = image_features.reshape(image_features.shape[0], width, height, -1)
n, h, w, c = x.size()
-        scale_factor = 1. / self.downsample_factor
+        scale_factor = 1.0 / self.downsample_factor
nh = int(h * scale_factor)
nw = int(w * scale_factor)
-        x = x.reshape(n, nh, self.downsample_factor, nw,
-                      self.downsample_factor, c)
+        x = x.reshape(n, nh, self.downsample_factor, nw, self.downsample_factor, c)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
x = x.reshape(n, nh, nw, -1)
return x
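For intuition, a minimal standalone sketch of this pixel-shuffle step, assuming a batch of 1, an 8x8 token grid, feature dim 16, and a downsample factor of 2 (made-up sizes, not the model's real shapes):

import torch

features = torch.randn(1, 64, 16)            # [B, S, D] with S = H*W = 8*8
height = width = int(features.shape[1] ** 0.5)
factor = 2
x = features.reshape(1, width, height, -1)   # [1, 8, 8, 16]
nh, nw = height // factor, width // factor
x = x.reshape(1, nh, factor, nw, factor, 16)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
x = x.reshape(1, nh, nw, -1)
print(x.shape)  # torch.Size([1, 4, 4, 64]): 4x fewer positions, 4x wider features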
class Cohere2VisionProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> Cohere2VisionConfig:
return self.ctx.get_hf_config(Cohere2VisionConfig)
@@ -146,8 +153,8 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
-        height = image_processor.size['height']
-        width = image_processor.size['width']
+        height = image_processor.size["height"]
+        width = image_processor.size["width"]
max_patches = image_processor.max_patches
return ImageSize(height=height * max_patches, width=width)
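To make that size computation concrete, a tiny sketch with placeholder values (not the actual image-processor settings):

# placeholder tile size and patch budget, for illustration only
height, width, max_patches = 512, 512, 12
print(height * max_patches, width)  # 6144 512 -> the tallest image layout the processor profiles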
@@ -196,8 +203,8 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo):
class Cohere2VisionDummyInputsBuilder(
-        BaseDummyInputsBuilder[Cohere2VisionProcessingInfo]):
+    BaseDummyInputsBuilder[Cohere2VisionProcessingInfo]
+):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
@@ -213,23 +220,23 @@ class Cohere2VisionDummyInputsBuilder(
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
-        image_size = \
-            self.info.get_image_size_with_most_features()
+        image_size = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return {
"image":
self._get_dummy_images(width=image_size.width,
height=image_size.height,
num_images=num_images,
overrides=image_overrides)
"image": self._get_dummy_images(
width=image_size.width,
height=image_size.height,
num_images=num_images,
overrides=image_overrides,
)
}
class Cohere2VisionMultiModalProcessor(
-        BaseMultiModalProcessor[Cohere2VisionProcessingInfo]):
+    BaseMultiModalProcessor[Cohere2VisionProcessingInfo]
+):
def _call_hf_processor(
self,
prompt: str,
@@ -245,22 +252,26 @@ class Cohere2VisionMultiModalProcessor(
)
# Ensure num_patches is available for proper tensor splitting
if "num_patches" not in processed_outputs and (
images := mm_data.get("images")) is not None:
if (
"num_patches" not in processed_outputs
and (images := mm_data.get("images")) is not None
):
hf_processor = self.info.get_hf_processor(**mm_kwargs)
# Fallback calculation if HF processor didn't provide num_patches
-            parsed_images = self._get_data_parser().parse_mm_data({
-                "image":
-                images
-            }).get_items("image", ImageProcessorItems)
+            parsed_images = (
+                self._get_data_parser()
+                .parse_mm_data({"image": images})
+                .get_items("image", ImageProcessorItems)
+            )
num_patches = [
self.info.get_num_patches(
image_width=parsed_images.get_image_size(i).width,
image_height=parsed_images.get_image_size(i).height,
processor=hf_processor,
-                ) for i in range(len(parsed_images))
+                )
+                for i in range(len(parsed_images))
]
processed_outputs["num_patches"] = torch.tensor(num_patches)
@@ -273,8 +284,7 @@ class Cohere2VisionMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
num_patches = hf_inputs.get("num_patches", torch.empty(0))
return dict(
-            pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_patches),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@@ -301,8 +311,7 @@ class Cohere2VisionMultiModalProcessor(
image_height=image_size.height,
processor=hf_processor,
)
-        patch_tokens = (image_token * img_tokens_per_tile +
-                        img_line_break_token)
+        patch_tokens = image_token * img_tokens_per_tile + img_line_break_token
repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}"
return PromptUpdateDetails.select_text(repl, image_token)
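A hedged illustration of how that replacement string is assembled; the token strings and counts below are placeholders, not the model's real special tokens or tile counts:

boi_token, eoi_token = "<image>", "</image>"          # placeholder boundary tokens
image_token, img_line_break_token = "<TILE>", "<BR>"  # placeholder content tokens
img_tokens_per_tile, num_patches = 3, 2
patch_tokens = image_token * img_tokens_per_tile + img_line_break_token
repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}"
print(repl)  # <image><TILE><TILE><TILE><BR><TILE><TILE><TILE><BR></image>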
@@ -319,9 +328,9 @@ class Cohere2VisionMultiModalProcessor(
@MULTIMODAL_REGISTRY.register_processor(
Cohere2VisionMultiModalProcessor,
info=Cohere2VisionProcessingInfo,
-    dummy_inputs=Cohere2VisionDummyInputsBuilder)
-class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                            SupportsPP):
+    dummy_inputs=Cohere2VisionDummyInputsBuilder,
+)
+class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
@@ -330,7 +339,8 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
"model.multi_modal_projector.": "multi_modal_projector.",
"model.language_model.": "language_model.model.",
"lm_head.": "language_model.lm_head.",
-        })
+        }
+    )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -342,37 +352,39 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multimodal_config = multimodal_config
self._patch_quant_config(config, quant_config)
-        self.vision_tower = SiglipVisionModel(config.vision_config,
-                                              quant_config,
-                                              prefix=maybe_prefix(
-                                                  prefix, "vision_tower"))
+        self.vision_tower = SiglipVisionModel(
+            config.vision_config,
+            quant_config,
+            prefix=maybe_prefix(prefix, "vision_tower"),
+        )
self.vocab_size = config.text_config.vocab_size
-        self.multi_modal_projector = \
-            Cohere2VisionMultiModalProjector(
-                config, prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        self.multi_modal_projector = Cohere2VisionMultiModalProjector(
+            config, prefix=maybe_prefix(prefix, "multi_modal_projector")
+        )
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
-            architectures=config.text_config.architectures)
+            architectures=config.text_config.architectures,
+        )
@property
def dtype(self):
return next(self.parameters()).dtype
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-    def _process_image_input(self, image_input: Cohere2VisionImagePixelInputs,
-                             **kwargs) -> list[torch.Tensor]:
+    def _process_image_input(
+        self, image_input: Cohere2VisionImagePixelInputs, **kwargs
+    ) -> list[torch.Tensor]:
"""Process image pixels through vision tower and projector.
Args:
-            image_input: Validated image input containing pixel values and
+            image_input: Validated image input containing pixel values and
patch counts
Returns:
List of flattened image embeddings, one per image
"""
@@ -388,17 +400,15 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = self.multi_modal_projector(image_features)
# Split and flatten embeddings per image
-        return [
-            e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())
-        ]
+        return [e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())]
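A minimal sketch of the per-image split-and-flatten above, using made-up shapes (5 projected patches shared between two images, a 4x4 token grid per patch, embedding dim 8):

import torch

image_embeds = torch.randn(5, 4, 4, 8)   # [total_patches, grid_h, grid_w, dim]
num_patches = torch.tensor([2, 3])       # patches belonging to image 0 and image 1
per_image = [e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())]
print([t.shape for t in per_image])      # [torch.Size([32, 8]), torch.Size([48, 8])]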
def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[Cohere2VisionImagePixelInputs]:
+        self, **kwargs: object
+    ) -> Optional[Cohere2VisionImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
num_patches = kwargs.pop("num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
-        assert image_embeds is None, \
-            "Cohere2Vision does not support image_embeds."
+        assert image_embeds is None, "Cohere2Vision does not support image_embeds."
if pixel_values is None:
return None
@@ -410,25 +420,26 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
resolve_bindings={
"h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size,
-            })
+            },
+        )
-    def _patch_quant_config(self, config: PretrainedConfig,
-                            quant_config: QuantizationConfig):
+    def _patch_quant_config(
+        self, config: PretrainedConfig, quant_config: QuantizationConfig
+    ):
# the awq models from OpenGVLab missing `modules_to_not_convert`
# patch the quant_config to add `modules_to_not_convert` back
if isinstance(quant_config, AWQConfig):
text_config = config.text_config
-            llm_quant_config = getattr(text_config, "quantization_config",
-                                       None)
-            if (not quant_config.modules_to_not_convert) and (llm_quant_config
-                                                              is not None):
+            llm_quant_config = getattr(text_config, "quantization_config", None)
+            if (not quant_config.modules_to_not_convert) and (
+                llm_quant_config is not None
+            ):
quant_config.modules_to_not_convert.append("vision_tower")
def get_language_model(self) -> torch.nn.Module:
return self.language_model
-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []