Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-10-05 15:06:22 +01:00 · committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
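
A migration like this typically replaces the separate yapf and isort configuration with a single ruff section in pyproject.toml. The following is a minimal sketch using ruff's documented options, not the exact settings this commit ships; vLLM's real line length and rule selection may differ:

    [tool.ruff]
    # Assumed line length; 88 is the black default that ruff format follows.
    line-length = 88

    [tool.ruff.lint]
    # "I" enables ruff's isort-compatible import-sorting rules,
    # replacing the standalone isort tool.
    extend-select = ["I"]

    # ruff format is a black-compatible formatter and takes over from yapf;
    # with no extra options it produces the style visible in the diff below:
    # exploded parenthesized lists with trailing commas, and no blank line
    # directly after a class statement.
    [tool.ruff.format]

The reformatted file below (one of the 1508 touched) shows the resulting style changes; no behavior changes.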


@@ -14,18 +14,20 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.utils import divide
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.ovis import AIMv2Config
 
 
 class AIMv2SwiGLUFFN(nn.Module):
-
-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
+    ):
         super().__init__()
         hidden_features = config.intermediate_size
         in_features = config.hidden_size
@@ -55,7 +57,6 @@ class AIMv2SwiGLUFFN(nn.Module):
 
 
 class AIMv2PatchEmbed(nn.Module):
-
     def __init__(self, config: AIMv2Config):
         super().__init__()
         self.proj = nn.Conv2d(
@@ -73,14 +74,12 @@ class AIMv2PatchEmbed(nn.Module):
 
 
 class AIMv2ViTPreprocessor(nn.Module):
-
     def __init__(self, config: AIMv2Config):
         super().__init__()
-        num_patches = (config.image_size // config.patch_size)**2
+        num_patches = (config.image_size // config.patch_size) ** 2
         self.patchifier = AIMv2PatchEmbed(config)
-        self.pos_embed = nn.Parameter(
-            torch.zeros((1, num_patches, config.hidden_size)))
+        self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         tokens = self.patchifier(x)
@@ -91,9 +90,9 @@ class AIMv2ViTPreprocessor(nn.Module):
 
 
 class AIMv2Attention(nn.Module):
-
-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
+    ):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -103,7 +102,8 @@ class AIMv2Attention(nn.Module):
             raise ValueError(
                 "embed_dim must be divisible by num_heads "
                 f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {self.num_heads}).")
+                f" {self.num_heads})."
+            )
         self.scale = self.head_dim**-0.5
         self.qkv = QKVParallelLinear(
@@ -126,8 +126,9 @@ class AIMv2Attention(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
-        self.attn = MultiHeadAttention(self.num_heads_per_partition,
-                                       self.head_dim, self.scale)
+        self.attn = MultiHeadAttention(
+            self.num_heads_per_partition, self.head_dim, self.scale
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         qkv, _ = self.qkv(x)
@@ -139,17 +140,17 @@ class AIMv2Attention(nn.Module):
 
 
 class AIMv2Block(nn.Module):
-
-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
+    ):
         super().__init__()
-        self.attn = AIMv2Attention(config,
-                                   quant_config=quant_config,
-                                   prefix=f"{prefix}.attn")
+        self.attn = AIMv2Attention(
+            config, quant_config=quant_config, prefix=f"{prefix}.attn"
+        )
         self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.mlp = AIMv2SwiGLUFFN(config,
-                                  quant_config=quant_config,
-                                  prefix=f"{prefix}.mlp")
+        self.mlp = AIMv2SwiGLUFFN(
+            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+        )
         self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,6 @@ class AIMv2Block(nn.Module):
 
 
 class AIMv2Transformer(nn.Module):
-
     def __init__(
         self,
         config: AIMv2Config,
@@ -170,13 +170,14 @@ class AIMv2Transformer(nn.Module):
     ):
         super().__init__()
-        self.blocks = nn.ModuleList([
-            AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
-            for i in range(config.num_hidden_layers)
-        ])
+        self.blocks = nn.ModuleList(
+            [
+                AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
+                for i in range(config.num_hidden_layers)
+            ]
+        )
         if require_post_norm:
-            self.post_trunk_norm = RMSNorm(config.hidden_size,
-                                           eps=config.rms_norm_eps)
+            self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
             self.post_trunk_norm = None
@@ -190,29 +191,30 @@ class AIMv2Transformer(nn.Module):
 
 
 class AIMv2Model(torch.nn.Module):
-
-    def __init__(self,
-                 config: AIMv2Config,
-                 quant_config: QuantizationConfig,
-                 *,
-                 require_post_norm: Optional[bool] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        config: AIMv2Config,
+        quant_config: QuantizationConfig,
+        *,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.preprocessor = AIMv2ViTPreprocessor(config)
-        self.trunk = AIMv2Transformer(config,
-                                      quant_config=quant_config,
-                                      require_post_norm=require_post_norm,
-                                      prefix=f"{prefix}.trunk")
+        self.trunk = AIMv2Transformer(
+            config,
+            quant_config=quant_config,
+            require_post_norm=require_post_norm,
+            prefix=f"{prefix}.trunk",
+        )
 
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         x = self.preprocessor(pixel_values)
         x = self.trunk(x)
         return x
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".fc13", ".fc1", 0),
@@ -223,11 +225,13 @@ class AIMv2Model(torch.nn.Module):
         for name, loaded_weight in weights:
             # post_layernorm is optional in SiglipVisionModel
-            if (name.startswith("trunk.post_trunk_norm")
-                    and self.trunk.post_trunk_norm is None):
+            if (
+                name.startswith("trunk.post_trunk_norm")
+                and self.trunk.post_trunk_norm is None
+            ):
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -238,8 +242,7 @@ class AIMv2Model(torch.nn.Module):
                 break
             else:
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
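
Day to day, the two-tool workflow collapses into one. A usage sketch, assuming ruff is installed and configured as above (the project's actual entry points, e.g. pre-commit hooks, may wrap these commands):

    # Sort imports in place (replaces isort)
    ruff check --select I --fix .

    # Reformat the tree (replaces yapf)
    ruff format .

Reformatting the whole tree at once is what makes the commit this large: 1508 files change while runtime behavior stays the same.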