Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author:    Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:      2025-10-05 15:06:22 +01:00
Committed: GitHub
Parent:    17edd8a807
Commit:    d6953beb91

1508 changed files with 115244 additions and 94146 deletions
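For context on what such a migration usually touches: yapf and isort each read their own configuration, while ruff replaces both with a single tool configured in pyproject.toml. The exact settings this commit adopts are not visible in this excerpt, so the following is a minimal, hypothetical sketch using ruff options that do exist (ruff format for black-style formatting, lint rule group "I" for isort-compatible import sorting):

    # pyproject.toml -- hypothetical minimal replacement for the
    # [tool.yapf] and [tool.isort] sections (values assumed, not
    # taken from this commit)
    [tool.ruff]
    line-length = 88          # ruff format's default, black-compatible

    [tool.ruff.lint]
    select = ["I"]            # "I" = isort-compatible import-sorting rules

    # One-time reformat of the tree:
    #   ruff check --fix .    # apply import sorting and other selected fixes
    #   ruff format .         # apply the formatter

The diff below (one file of the 1508) shows the resulting mechanical restyling.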

@@ -24,14 +24,18 @@ from typing import Optional
 import torch
 from torch import nn
 from transformers.models.idefics2.configuration_idefics2 import (
-    Idefics2Config, Idefics2VisionConfig)
+    Idefics2Config,
+    Idefics2VisionConfig,
+)
 
 from vllm.attention.layer import MultiHeadAttention
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -67,13 +71,14 @@ class Idefics2VisionEmbeddings(nn.Module):
         self.num_patches_per_side = self.image_size // self.patch_size
         self.num_patches = self.num_patches_per_side**2
         self.num_positions = self.num_patches
-        self.position_embedding = nn.Embedding(self.num_positions,
-                                               self.embed_dim)
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
 
-    def forward(self,
-                pixel_values: torch.FloatTensor,
-                patch_attention_mask: torch.BoolTensor,
-                tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        patch_attention_mask: torch.BoolTensor,
+        tgt_sizes: Optional[torch.IntTensor] = None,
+    ) -> torch.Tensor:
         batch_size, _, max_im_h, max_im_w = pixel_values.shape
         target_dtype = self.patch_embedding.weight.dtype
         patch_embeds = self.patch_embedding(pixel_values.to(target_dtype))
@@ -82,14 +87,14 @@ class Idefics2VisionEmbeddings(nn.Module):
             max_im_h // self.patch_size,
             max_im_w // self.patch_size,
         )
-        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
-                                  1 / self.num_patches_per_side)
-        position_ids = torch.full(size=(batch_size,
-                                        max_nb_patches_h * max_nb_patches_w),
-                                  fill_value=0)
+        boundaries = torch.arange(
+            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
+        )
+        position_ids = torch.full(
+            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
+        )
 
         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
             if tgt_sizes is not None:
                 nb_patches_h = tgt_sizes[batch_idx][0]
                 nb_patches_w = tgt_sizes[batch_idx][1]
@@ -98,14 +103,15 @@ class Idefics2VisionEmbeddings(nn.Module):
                 nb_patches_w = p_attn_mask[0].sum()
             fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
             fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
-            bucket_coords_h = torch.bucketize(fractional_coords_h,
-                                              boundaries,
-                                              right=True)
-            bucket_coords_w = torch.bucketize(fractional_coords_w,
-                                              boundaries,
-                                              right=True)
-            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
-                       bucket_coords_w).flatten()
+            bucket_coords_h = torch.bucketize(
+                fractional_coords_h, boundaries, right=True
+            )
+            bucket_coords_w = torch.bucketize(
+                fractional_coords_w, boundaries, right=True
+            )
+            pos_ids = (
+                bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
+            ).flatten()
             position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
         position_ids = position_ids.to(self.position_embedding.weight.device)
         embeddings += self.position_embedding(position_ids)
@@ -130,12 +136,12 @@ class Idefics2VisionAttention(nn.Module):
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"  # noqa: E501
-                f" {self.num_heads}).")
+                f" {self.num_heads})."
+            )
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
-        tp_size = (1 if use_data_parallel else
-                   get_tensor_model_parallel_world_size())
+        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
         assert self.num_heads % tp_size == 0
         self.num_heads_per_partition = self.num_heads // tp_size
@@ -156,8 +162,9 @@ class Idefics2VisionAttention(nn.Module):
             disable_tp=use_data_parallel,
         )
         # Use unified MultiHeadAttention with Flash Attention support
-        self.attn = MultiHeadAttention(self.num_heads_per_partition,
-                                       self.head_dim, self.scale)
+        self.attn = MultiHeadAttention(
+            self.num_heads_per_partition, self.head_dim, self.scale
+        )
 
     def forward(
         self,
@@ -175,7 +182,6 @@ class Idefics2VisionAttention(nn.Module):
 
 
 class Idefics2VisionMLP(nn.Module):
-
     def __init__(
         self,
         config: Idefics2VisionConfig,
@@ -211,7 +217,6 @@ class Idefics2VisionMLP(nn.Module):
 
 
 class Idefics2EncoderLayer(nn.Module):
-
     def __init__(
         self,
         config: Idefics2Config,
@@ -225,15 +230,16 @@ class Idefics2EncoderLayer(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
-            use_data_parallel=use_data_parallel)
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
-                                        eps=config.layer_norm_eps)
-        self.mlp = Idefics2VisionMLP(config,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.mlp",
-                                     use_data_parallel=use_data_parallel)
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
-                                        eps=config.layer_norm_eps)
+            use_data_parallel=use_data_parallel,
+        )
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = Idefics2VisionMLP(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
+        )
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -284,13 +290,17 @@ class Idefics2Encoder(nn.Module):
         else:
             num_hidden_layers = num_hidden_layers_override
-        self.layers = nn.ModuleList([
-            Idefics2EncoderLayer(config,
-                                 quant_config=quant_config,
-                                 prefix=f"{prefix}.layers.{layer_idx}",
-                                 use_data_parallel=use_data_parallel)
-            for layer_idx in range(num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                Idefics2EncoderLayer(
+                    config,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                    use_data_parallel=use_data_parallel,
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
 
     def forward(
         self,
@@ -313,7 +323,6 @@ class Idefics2Encoder(nn.Module):
 
 
 class Idefics2VisionTransformer(nn.Module):
-
     def __init__(
         self,
         config: Idefics2VisionConfig,
@@ -335,7 +344,8 @@ class Idefics2VisionTransformer(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+        )
 
         num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
@@ -345,10 +355,14 @@ class Idefics2VisionTransformer(nn.Module):
             )
         self.require_post_norm = require_post_norm
-        self.post_layernorm = nn.LayerNorm(
-            embed_dim,
-            eps=config.layer_norm_eps,
-        ) if require_post_norm else nn.Identity()
+        self.post_layernorm = (
+            nn.LayerNorm(
+                embed_dim,
+                eps=config.layer_norm_eps,
+            )
+            if require_post_norm
+            else nn.Identity()
+        )
 
     def get_input_embeddings(self):
         return self.embeddings
@@ -365,15 +379,13 @@ class Idefics2VisionTransformer(nn.Module):
             tgt_sizes=tgt_sizes,
         )
         if self.use_data_parallel:
-            encoder_outputs = run_dp_sharded_vision_model(
-                hidden_states, self.encoder)
+            encoder_outputs = run_dp_sharded_vision_model(hidden_states, self.encoder)
         else:
             encoder_outputs = self.encoder(hidden_states)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -390,8 +402,7 @@ class Idefics2VisionTransformer(nn.Module):
                 continue
             # post_layernorm is optional
-            if (name.startswith("post_layernorm.")
-                    and not self.require_post_norm):
+            if name.startswith("post_layernorm.") and not self.require_post_norm:
                 continue
 
             # omit layers when num_hidden_layers_override is set
@@ -410,8 +421,7 @@ class Idefics2VisionTransformer(nn.Module):
                     break
             else:
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
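
The hunks above all follow from two black-compatible rules that ruff format applies: a call that fits within the line-length limit is collapsed onto a single line, and one that does not gets its arguments moved into an indented block, where a trailing ("magic") comma pins one argument per line on later reformats. A minimal before/after sketch of the pattern, using a call taken from this diff:

    # Before (yapf): continuation lines aligned under the opening parenthesis.
    self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                   self.head_dim, self.scale)

    # After (ruff format): too long for one line, so the arguments drop into
    # an indented block; with no trailing comma they may share a line.
    self.attn = MultiHeadAttention(
        self.num_heads_per_partition, self.head_dim, self.scale
    )

    # Had the source carried a trailing comma, ruff format would instead
    # keep one argument per line:
    self.attn = MultiHeadAttention(
        self.num_heads_per_partition,
        self.head_dim,
        self.scale,
    )

This is why the diff is large but behavior-preserving: only bracket placement, argument wrapping, and blank lines change.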