Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
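The hunks below show the mechanical effect of the switch on one model file. The PR's actual tooling changes (pyproject.toml settings, pre-commit hooks, CI checks) are not reproduced in this excerpt. As a minimal, illustrative sketch of what the new toolchain covers, assuming only that ruff is installed, the two invocations below correspond to what yapf (code formatting) and isort (import sorting) used to do:

# Illustrative sketch only: the repository's real configuration lives in its own
# pyproject.toml / pre-commit setup, which this excerpt does not show.
import subprocess

# "ruff format" is a black-compatible code formatter; it takes over from yapf.
subprocess.run(["ruff", "format", "."], check=True)

# "ruff check" with the isort rule family ("I") sorts imports; it takes over from isort.
subprocess.run(["ruff", "check", "--select", "I", "--fix", "."], check=True)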
@@ -14,18 +14,20 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.utils import divide
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.ovis import AIMv2Config


 class AIMv2SwiGLUFFN(nn.Module):
-
-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
+    ):
         super().__init__()
         hidden_features = config.intermediate_size
         in_features = config.hidden_size
@@ -55,7 +57,6 @@ class AIMv2SwiGLUFFN(nn.Module):


 class AIMv2PatchEmbed(nn.Module):
-
     def __init__(self, config: AIMv2Config):
         super().__init__()
         self.proj = nn.Conv2d(
@@ -73,14 +74,12 @@ class AIMv2PatchEmbed(nn.Module):


 class AIMv2ViTPreprocessor(nn.Module):
-
     def __init__(self, config: AIMv2Config):
         super().__init__()
-        num_patches = (config.image_size // config.patch_size)**2
+        num_patches = (config.image_size // config.patch_size) ** 2

         self.patchifier = AIMv2PatchEmbed(config)
-        self.pos_embed = nn.Parameter(
-            torch.zeros((1, num_patches, config.hidden_size)))
+        self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         tokens = self.patchifier(x)
@@ -91,9 +90,9 @@ class AIMv2ViTPreprocessor(nn.Module):


 class AIMv2Attention(nn.Module):
-
-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
+    ):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -103,7 +102,8 @@ class AIMv2Attention(nn.Module):
             raise ValueError(
                 "embed_dim must be divisible by num_heads "
                 f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {self.num_heads}).")
+                f" {self.num_heads})."
+            )
         self.scale = self.head_dim**-0.5

         self.qkv = QKVParallelLinear(
@@ -126,8 +126,9 @@ class AIMv2Attention(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

-        self.attn = MultiHeadAttention(self.num_heads_per_partition,
-                                       self.head_dim, self.scale)
+        self.attn = MultiHeadAttention(
+            self.num_heads_per_partition, self.head_dim, self.scale
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         qkv, _ = self.qkv(x)
@@ -139,17 +140,17 @@ class AIMv2Attention(nn.Module):


 class AIMv2Block(nn.Module):
-
-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
+    ):
         super().__init__()
-        self.attn = AIMv2Attention(config,
-                                   quant_config=quant_config,
-                                   prefix=f"{prefix}.attn")
+        self.attn = AIMv2Attention(
+            config, quant_config=quant_config, prefix=f"{prefix}.attn"
+        )
         self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.mlp = AIMv2SwiGLUFFN(config,
-                                  quant_config=quant_config,
-                                  prefix=f"{prefix}.mlp")
+        self.mlp = AIMv2SwiGLUFFN(
+            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+        )
         self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,6 @@ class AIMv2Block(nn.Module):


 class AIMv2Transformer(nn.Module):
-
     def __init__(
         self,
         config: AIMv2Config,
@@ -170,13 +170,14 @@ class AIMv2Transformer(nn.Module):
     ):
         super().__init__()

-        self.blocks = nn.ModuleList([
-            AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
-            for i in range(config.num_hidden_layers)
-        ])
+        self.blocks = nn.ModuleList(
+            [
+                AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
+                for i in range(config.num_hidden_layers)
+            ]
+        )
         if require_post_norm:
-            self.post_trunk_norm = RMSNorm(config.hidden_size,
-                                           eps=config.rms_norm_eps)
+            self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
             self.post_trunk_norm = None

@@ -190,29 +191,30 @@ class AIMv2Transformer(nn.Module):


 class AIMv2Model(torch.nn.Module):
-
-    def __init__(self,
-                 config: AIMv2Config,
-                 quant_config: QuantizationConfig,
-                 *,
-                 require_post_norm: Optional[bool] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        config: AIMv2Config,
+        quant_config: QuantizationConfig,
+        *,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.preprocessor = AIMv2ViTPreprocessor(config)
-        self.trunk = AIMv2Transformer(config,
-                                      quant_config=quant_config,
-                                      require_post_norm=require_post_norm,
-                                      prefix=f"{prefix}.trunk")
+        self.trunk = AIMv2Transformer(
+            config,
+            quant_config=quant_config,
+            require_post_norm=require_post_norm,
+            prefix=f"{prefix}.trunk",
+        )

     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-
         x = self.preprocessor(pixel_values)
         x = self.trunk(x)

         return x

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".fc13", ".fc1", 0),
@@ -223,11 +225,13 @@ class AIMv2Model(torch.nn.Module):

         for name, loaded_weight in weights:
             # post_layernorm is optional in SiglipVisionModel
-            if (name.startswith("trunk.post_trunk_norm")
-                    and self.trunk.post_trunk_norm is None):
+            if (
+                name.startswith("trunk.post_trunk_norm")
+                and self.trunk.post_trunk_norm is None
+            ):
                 continue

-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -238,8 +242,7 @@ class AIMv2Model(torch.nn.Module):
                     break
             else:
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
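The reformatting above follows two rules of ruff's black-style output worth calling out. A call (or signature) that fits within the line limit is collapsed onto a single line, as with RMSNorm(...) and getattr(...) above; one that does not fit is exploded, first with all arguments on a single indented line (MultiHeadAttention(...)), and, when even that does not fit or a trailing comma is present, with one argument per line and the closing parenthesis dedented (AIMv2Transformer(...) and the import block). A minimal sketch, using a hypothetical make() helper that is not part of vLLM:

# Hypothetical helper, used only to illustrate the formatter's call layouts.
def make(a: int, b: int, c: int) -> tuple[int, int, int]:
    return a, b, c


# Fits on one line and carries no trailing comma, so ruff collapses it:
short = make(1, 2, 3)

# A trailing ("magic") comma keeps one argument per line, with the closing
# parenthesis on its own line:
exploded = make(
    1,
    2,
    3,
)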