Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00 (committed by GitHub)
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
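The switch consolidates two tools into one: ruff's formatter takes over what yapf did, and its "I" (isort) lint rules take over import sorting. As a rough illustration only — the section names below are real ruff configuration keys, but the concrete values are assumptions, not necessarily what vLLM's pyproject.toml actually sets:

    [tool.ruff]
    line-length = 88  # assumed value; ruff's default, matching Black

    [tool.ruff.lint]
    # "I" enables isort-style import-sorting rules, applied via `ruff check --fix`
    select = ["I"]

    [tool.ruff.format]
    # ruff's Black-compatible formatter, applied via `ruff format`;
    # the defaults need no extra options here

Running `ruff check --fix` followed by `ruff format` over the tree produces the kind of mechanical changes shown in the diff below: parenthesized imports exploded one-per-line with trailing commas, spaces around `**` when an operand is a call, and hanging continuation lines collapsed.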

--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py

@@ -14,28 +14,33 @@ from transformers import SiglipVisionConfig
 from vllm.attention.layer import MultiHeadAttention
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, maybe_remap_kv_scale_name)
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)

-from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
-                     resolve_visual_encoder_outputs)
+from .vision import (
+    VisionEncoderInfo,
+    VisionFeatureSelectStrategy,
+    resolve_visual_encoder_outputs,
+)


 class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
-
     def get_num_image_tokens(
         self,
         *,
         image_width: int,
         image_height: int,
     ) -> int:
-        return self.get_patch_grid_length()**2
+        return self.get_patch_grid_length() ** 2

     def get_image_size(self) -> int:
         return self.vision_config.image_size
@@ -50,7 +55,6 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):

 # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
 class SiglipVisionEmbeddings(nn.Module):
-
     def __init__(self, config: SiglipVisionConfig):
         super().__init__()
         self.config = config
@@ -66,19 +70,20 @@ class SiglipVisionEmbeddings(nn.Module):
             padding="valid",
         )

-        self.num_patches = (self.image_size // self.patch_size)**2
+        self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches
         self.position_embedding = VocabParallelEmbedding(
-            self.num_positions, self.embed_dim)
+            self.num_positions, self.embed_dim
+        )
         self.register_buffer(
             "position_ids",
-            torch.arange(self.num_positions, dtype=torch.int64).expand(
-                (1, -1)),
+            torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)),
             persistent=False,
         )

-    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
-                                 width: int) -> torch.Tensor:
+    def interpolate_pos_encoding(
+        self, embeddings: torch.Tensor, height: int, width: int
+    ) -> torch.Tensor:
         """
         This method is an adapted method for SigLIP (due to SigLIP not having
         class embedding unlike other ViTs) that allows the model to interpolate
@@ -103,8 +108,8 @@ class SiglipVisionEmbeddings(nn.Module):
         height, width = height + 0.1, width + 0.1

         patch_pos_embed = position_embeddings.reshape(
-            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
-            dim)
+            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
+        )
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
@@ -115,33 +120,36 @@ class SiglipVisionEmbeddings(nn.Module):
             mode="bicubic",
             align_corners=False,
         )
-        if (int(height) != patch_pos_embed.shape[-2]
-                or int(width) != patch_pos_embed.shape[-1]):
-            raise ValueError("Width or height does not match with "
-                             "the interpolated position embeddings")
+        if (
+            int(height) != patch_pos_embed.shape[-2]
+            or int(width) != patch_pos_embed.shape[-1]
+        ):
+            raise ValueError(
+                "Width or height does not match with "
+                "the interpolated position embeddings"
+            )

         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return patch_pos_embed

-    def forward(self,
-                pixel_values: torch.Tensor,
-                interpolate_pos_encoding: bool = False) -> torch.Tensor:
+    def forward(
+        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
+    ) -> torch.Tensor:
         _, _, height, width = pixel_values.shape
         target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(pixel_values.to(
-            dtype=target_dtype))  # shape = [*, width, grid, grid]
+        patch_embeds = self.patch_embedding(
+            pixel_values.to(dtype=target_dtype)
+        )  # shape = [*, width, grid, grid]

         embeddings = patch_embeds.flatten(2).transpose(1, 2)

         if interpolate_pos_encoding:
-            embeddings += self.interpolate_pos_encoding(
-                embeddings, height, width)
+            embeddings += self.interpolate_pos_encoding(embeddings, height, width)
         else:
             embeddings += self.position_embedding(self.position_ids)

         return embeddings


 class SiglipAttention(nn.Module):
-
     def __init__(
         self,
         config: SiglipVisionConfig,
@@ -155,9 +163,11 @@ class SiglipAttention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
         if self.head_dim * self.num_heads != self.embed_dim:
-            raise ValueError(f"embed_dim must be divisible by num_heads (got "
-                             "`embed_dim`: {self.embed_dim} and `num_heads`:"
-                             f" {self.num_heads}).")
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got "
+                f"`embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
@@ -179,8 +189,9 @@ class SiglipAttention(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

-        self.attn = MultiHeadAttention(self.num_heads_per_partition,
-                                       self.head_dim, self.scale)
+        self.attn = MultiHeadAttention(
+            self.num_heads_per_partition, self.head_dim, self.scale
+        )

     def forward(
         self,
@@ -197,7 +208,6 @@ class SiglipAttention(nn.Module):


 class SiglipMLP(nn.Module):
-
     def __init__(
         self,
         config: SiglipVisionConfig,
@@ -209,15 +219,14 @@ class SiglipMLP(nn.Module):
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
         # Special handling for BNB and torchao quantization
-        if quant_config and quant_config.get_name() in [
-                "bitsandbytes", "torchao"
-        ]:
+        if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
             quantizable = True
         else:
             # For other quantization, we require the hidden size to be a
             # multiple of 64
-            quantizable = (config.hidden_size % 64 == 0
-                           and config.intermediate_size % 64 == 0)
+            quantizable = (
+                config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
+            )
         self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
@@ -239,7 +248,6 @@ class SiglipMLP(nn.Module):


 class SiglipEncoderLayer(nn.Module):
-
     def __init__(
         self,
         config: SiglipVisionConfig,
@@ -255,15 +263,13 @@ class SiglipEncoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
         )
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
-                                        eps=config.layer_norm_eps)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
         )
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
-                                        eps=config.layer_norm_eps)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

     def forward(
         self,
@@ -284,7 +290,6 @@ class SiglipEncoderLayer(nn.Module):


 class SiglipEncoder(nn.Module):
-
     def __init__(
         self,
         config: SiglipVisionConfig,
@@ -301,12 +306,16 @@ class SiglipEncoder(nn.Module):
         else:
             num_hidden_layers = num_hidden_layers_override

-        self.layers = nn.ModuleList([
-            SiglipEncoderLayer(config,
-                               quant_config=quant_config,
-                               prefix=f"{prefix}.layers.{layer_idx}")
-            for layer_idx in range(num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                SiglipEncoderLayer(
+                    config,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )

     def forward(
         self,
@@ -341,12 +350,12 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
         self.attention = torch.nn.MultiheadAttention(
-            config.hidden_size, config.num_attention_heads, batch_first=True)
-        self.layernorm = nn.LayerNorm(config.hidden_size,
-                                      eps=config.layer_norm_eps)
-        self.mlp = SiglipMLP(config=config,
-                             quant_config=quant_config,
-                             prefix=f"{prefix}.mlp")
+            config.hidden_size, config.num_attention_heads, batch_first=True
+        )
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.mlp = SiglipMLP(
+            config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+        )

     def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
         batch_size = hidden_state.shape[0]
@@ -363,7 +372,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):


 class SiglipVisionTransformer(nn.Module):
-
     def __init__(
         self,
         config: SiglipVisionConfig,
@@ -399,13 +407,13 @@ class SiglipVisionTransformer(nn.Module):
         require_post_norm = len(self.encoder.layers) == num_hidden_layers

         if require_post_norm:
-            self.post_layernorm = nn.LayerNorm(embed_dim,
-                                               eps=config.layer_norm_eps)
+            self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
         else:
             self.post_layernorm = None

-        self.use_head = (True if not hasattr(config, "vision_use_head") else
-                         config.vision_use_head)
+        self.use_head = (
+            True if not hasattr(config, "vision_use_head") else config.vision_use_head
+        )
         if self.use_head:
             self.head = SiglipMultiheadAttentionPoolingHead(
                 config=config,
@@ -493,8 +501,7 @@ class SiglipVisionModel(nn.Module):
             feature_select_strategy=feature_select_strategy,
         )

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -507,8 +514,10 @@ class SiglipVisionModel(nn.Module):

         for name, loaded_weight in weights:
             # post_layernorm is optional in SiglipVisionModel
-            if (name.startswith("vision_model.post_layernorm")
-                    and self.vision_model.post_layernorm is None):
+            if (
+                name.startswith("vision_model.post_layernorm")
+                and self.vision_model.post_layernorm is None
+            ):
                 continue

             # omit layers when num_hidden_layers_override is set
@@ -518,21 +527,21 @@ class SiglipVisionModel(nn.Module):
                 continue

             # Check if this is a scale parameter that needs remapping first
-            if name.endswith(
-                (".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
+            if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
                 # Try to remap the scale name first
                 remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                 if remapped_name is not None and remapped_name in params_dict:
                     # Successfully remapped, use the remapped name
                     param = params_dict[remapped_name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
                     weight_loader(param, loaded_weight)
                     loaded_params.add(remapped_name)
                     continue
                 # If remapping failed, continue with normal processing

-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -543,8 +552,7 @@ class SiglipVisionModel(nn.Module):
                 break
             else:
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)
         return loaded_params
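
Since the commit is intended to be formatting-only, one quick way to convince yourself a reformat changed no behavior is to compare syntax trees before and after: Python's ast module ignores whitespace and line-wrapping, so a pure reformat parses to an identical tree. A minimal sketch, with hypothetical file paths, not part of the commit:

    import ast

    def same_ast(path_before: str, path_after: str) -> bool:
        """Return True if two source files parse to structurally identical ASTs."""
        with open(path_before) as f:
            tree_before = ast.parse(f.read())
        with open(path_after) as f:
            tree_after = ast.parse(f.read())
        # ast.dump() omits line/column attributes by default, so equal dumps
        # mean the reformat changed layout only, not program structure.
        return ast.dump(tree_before) == ast.dump(tree_after)

    # e.g. same_ast("siglip_before.py", "siglip_after.py")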