Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions

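The change is mechanical: ruff's black-compatible formatter replaces yapf's aligned continuation lines, and ruff's isort-compatible import rules replace isort itself. A minimal standalone sketch of the two styles (the toy `Cfg` class and its values are illustrative stand-ins for the real config objects, not code from this commit):

import torch.nn as nn

class Cfg:
    hidden_size = 768
    layer_norm_eps = 1e-5

cfg = Cfg()

# Before (yapf): continuation arguments aligned under the opening parenthesis.
norm = nn.LayerNorm(cfg.hidden_size,
                    eps=cfg.layer_norm_eps)

# After (ruff format): the call is collapsed onto one line when it fits the
# line-length limit; otherwise each argument gets its own line with a
# trailing comma, as seen throughout the diff below.
norm = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)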

@@ -6,27 +6,42 @@ from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
apply_chunking_to_forward)
from transformers import (
BatchFeature,
Blip2Config,
Blip2QFormerConfig,
apply_chunking_to_forward,
)
from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptUpdate)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptIndexTargets,
PromptInsertion,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .blip import BlipVisionModel
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
@@ -38,6 +53,7 @@ class Blip2ImagePixelInputs(TensorSchema):
- h: Height of each image
- w: Width of each image
"""
type: Literal["pixel_values"]
data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
@@ -49,6 +65,7 @@ class Blip2ImageEmbeddingInputs(TensorSchema):
- f: Image feature size
- h: Hidden size (must match the hidden size of language model backbone)
"""
type: Literal["image_embeds"]
data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
@@ -57,7 +74,6 @@ Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
class Blip2QFormerMultiHeadAttention(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
@@ -78,8 +94,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = (config.hidden_size //
config.num_attention_heads)
self.attention_head_size = config.hidden_size // config.num_attention_heads
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.scaling = self.attention_head_size**-0.5
@@ -91,18 +106,18 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
self.key = nn.Linear(kv_hidden_size, self.all_head_size)
self.value = nn.Linear(kv_hidden_size, self.all_head_size)
self.position_embedding_type = getattr(config,
"position_embedding_type",
"absolute")
self.position_embedding_type = getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type != "absolute":
raise NotImplementedError("Unsupported position_embedding_type: "
f"{self.position_embedding_type}")
raise NotImplementedError(
f"Unsupported position_embedding_type: {self.position_embedding_type}"
)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
x = x.view(*x.size()[:-1], self.num_attention_heads,
self.attention_head_size)
x = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
return x.permute(0, 2, 1, 3)
def forward(
@@ -113,10 +128,8 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(
self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(
self.value(encoder_hidden_states))
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
@@ -125,10 +138,8 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_probs = torch.softmax(attention_scores * self.scaling,
dim=-1)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_probs = torch.softmax(attention_scores * self.scaling, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
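A minimal standalone sketch of what the comment above describes (the tensor shapes and dropout probability are illustrative, not taken from the model): zeroing an entry of the attention probabilities removes that key/value token from the corresponding query's context vector.

import torch
import torch.nn.functional as F

probs = torch.softmax(torch.randn(1, 2, 4, 4), dim=-1)  # (batch, heads, query, key)
dropped = F.dropout(probs, p=0.5, training=True)
# Wherever dropped[..., q, k] == 0, key/value token k contributes nothing to
# query q in the matmul with the value tensor that follows.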
@@ -137,20 +148,19 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
context_layer = context_layer.view(*context_layer.size()[:-2],
self.all_head_size)
context_layer = context_layer.view(
*context_layer.size()[:-2], self.all_head_size
)
return context_layer
class Blip2QFormerSelfOutput(nn.Module):
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
@@ -165,7 +175,6 @@ class Blip2QFormerSelfOutput(nn.Module):
class Blip2QFormerAttention(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
@@ -202,7 +211,6 @@ class Blip2QFormerAttention(nn.Module):
class Blip2QFormerIntermediate(nn.Module):
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
super().__init__()
@@ -216,13 +224,11 @@ class Blip2QFormerIntermediate(nn.Module):
class Blip2QFormerOutput(nn.Module):
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
@@ -237,7 +243,6 @@ class Blip2QFormerOutput(nn.Module):
class Blip2QFormerLayer(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
@@ -251,10 +256,12 @@ class Blip2QFormerLayer(nn.Module):
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = Blip2QFormerAttention(config,
quant_config=quant_config,
cache_config=cache_config,
prefix=f"{prefix}.attention")
self.attention = Blip2QFormerAttention(
config,
quant_config=quant_config,
cache_config=cache_config,
prefix=f"{prefix}.attention",
)
self.layer_idx = layer_idx
@@ -264,15 +271,16 @@ class Blip2QFormerLayer(nn.Module):
quant_config=quant_config,
cache_config=cache_config,
is_cross_attention=True,
prefix=f"{prefix}.crossattention")
prefix=f"{prefix}.crossattention",
)
self.has_cross_attention = True
else:
self.has_cross_attention = False
self.intermediate_query = Blip2QFormerIntermediate(
config, prefix=f"{prefix}.intermediate_query")
self.output_query = Blip2QFormerOutput(config,
prefix=f"{prefix}.output_query")
config, prefix=f"{prefix}.intermediate_query"
)
self.output_query = Blip2QFormerOutput(config, prefix=f"{prefix}.output_query")
def forward(
self,
@@ -305,8 +313,7 @@ class Blip2QFormerLayer(nn.Module):
self.seq_len_dim,
attention_output[:, query_length:, :],
)
layer_output = torch.cat([layer_output, layer_output_text],
dim=1)
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
else:
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
@@ -317,21 +324,18 @@ class Blip2QFormerLayer(nn.Module):
return layer_output
def feed_forward_chunk(self,
attention_output: torch.Tensor) -> torch.Tensor:
def feed_forward_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def feed_forward_chunk_query(
self, attention_output: torch.Tensor) -> torch.Tensor:
def feed_forward_chunk_query(self, attention_output: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate_query(attention_output)
layer_output = self.output_query(intermediate_output, attention_output)
return layer_output
class Blip2QFormerEncoder(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
@@ -344,14 +348,18 @@ class Blip2QFormerEncoder(nn.Module):
self.config = config
self.layer = nn.ModuleList([
Blip2QFormerLayer(config,
quant_config=quant_config,
cache_config=cache_config,
layer_idx=layer_idx,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
self.layer = nn.ModuleList(
[
Blip2QFormerLayer(
config,
quant_config=quant_config,
cache_config=cache_config,
layer_idx=layer_idx,
prefix=f"{prefix}.layer.{layer_idx}",
)
for layer_idx in range(config.num_hidden_layers)
]
)
def forward(
self,
@@ -373,7 +381,6 @@ class Blip2QFormerEncoder(nn.Module):
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):
def __init__(
self,
config: Blip2QFormerConfig,
@@ -386,14 +393,15 @@ class Blip2QFormerModel(nn.Module):
self.config = config
self.layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.encoder = Blip2QFormerEncoder(config,
quant_config=quant_config,
cache_config=cache_config,
prefix=f"{prefix}.encoder")
self.encoder = Blip2QFormerEncoder(
config,
quant_config=quant_config,
cache_config=cache_config,
prefix=f"{prefix}.encoder",
)
def forward(
self,
@@ -415,7 +423,6 @@ class Blip2QFormerModel(nn.Module):
class Blip2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Blip2Config)
@@ -428,7 +435,6 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
@@ -447,16 +453,16 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
image_overrides = mm_options.get("image") if mm_options else None
return {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images,
overrides=image_overrides)
"image": self._get_dummy_images(
width=max_image_size,
height=max_image_size,
num_images=num_images,
overrides=image_overrides,
)
}
class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
@@ -509,11 +515,14 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
]
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
info=Blip2ProcessingInfo,
dummy_inputs=Blip2DummyInputsBuilder)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsQuant):
@MULTIMODAL_REGISTRY.register_processor(
Blip2MultiModalProcessor,
info=Blip2ProcessingInfo,
dummy_inputs=Blip2DummyInputsBuilder,
)
class Blip2ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
):
merge_by_field_config = True
@classmethod
@@ -524,7 +533,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@@ -537,13 +545,15 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self.vision_model = BlipVisionModel(config.vision_config, quant_config)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens,
config.qformer_config.hidden_size))
torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
)
self.qformer = Blip2QFormerModel(config.qformer_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.qformer")
self.qformer = Blip2QFormerModel(
config.qformer_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.qformer",
)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size,
@@ -558,10 +568,12 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
self, **kwargs: object
) -> Optional[Blip2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
@@ -570,12 +582,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
if pixel_values is not None:
expected_h = expected_w = self.config.vision_config.image_size
return Blip2ImagePixelInputs(type="pixel_values",
data=pixel_values,
resolve_bindings={
"h": expected_h,
"w": expected_w
})
return Blip2ImagePixelInputs(
type="pixel_values",
data=pixel_values,
resolve_bindings={"h": expected_h, "w": expected_w},
)
if image_embeds is not None:
return Blip2ImageEmbeddingInputs(
@@ -585,34 +596,30 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(self, vision_model: BlipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
def _image_pixels_to_features(
self, vision_model: BlipVisionModel, pixel_values: torch.Tensor
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_model(pixel_values)
return image_features
def _process_image_pixels(self,
inputs: Blip2ImagePixelInputs) -> torch.Tensor:
def _process_image_pixels(self, inputs: Blip2ImagePixelInputs) -> torch.Tensor:
assert self.vision_model is not None
pixel_values = inputs["data"]
return self._image_pixels_to_features(self.vision_model, pixel_values)
def _process_image_input(self,
image_input: Blip2ImageInputs) -> torch.Tensor:
def _process_image_input(self, image_input: Blip2ImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
-1)
query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)
query_output = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_features,
@@ -623,8 +630,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
@@ -651,7 +657,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
`[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
dummy tokens (denoted as `50265`), resulting in:
`[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.
@@ -664,7 +670,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
Info:
[`Blip2ImageInputs`][vllm.model_executor.models.blip2.Blip2ImageInputs]
"""
@@ -672,10 +678,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
@@ -685,7 +690,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)