Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -14,28 +14,33 @@ from transformers import SiglipVisionConfig
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
|
||||
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
|
||||
resolve_visual_encoder_outputs)
|
||||
from .vision import (
|
||||
VisionEncoderInfo,
|
||||
VisionFeatureSelectStrategy,
|
||||
resolve_visual_encoder_outputs,
|
||||
)
|
||||
|
||||
|
||||
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
return self.get_patch_grid_length()**2
|
||||
return self.get_patch_grid_length() ** 2
|
||||
|
||||
def get_image_size(self) -> int:
|
||||
return self.vision_config.image_size
|
||||
@@ -50,7 +55,6 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
|
||||
class SiglipVisionEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config: SiglipVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -66,19 +70,20 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
padding="valid",
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size)**2
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches
|
||||
self.position_embedding = VocabParallelEmbedding(
|
||||
self.num_positions, self.embed_dim)
|
||||
self.num_positions, self.embed_dim
|
||||
)
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(self.num_positions, dtype=torch.int64).expand(
|
||||
(1, -1)),
|
||||
torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
|
||||
width: int) -> torch.Tensor:
|
||||
def interpolate_pos_encoding(
|
||||
self, embeddings: torch.Tensor, height: int, width: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This method is an adapted method for SigLIP (due to SigLIP not having
|
||||
class embedding unlike other ViTs) that allows the model to interpolate
|
||||
@@ -103,8 +108,8 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
height, width = height + 0.1, width + 0.1
|
||||
|
||||
patch_pos_embed = position_embeddings.reshape(
|
||||
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
|
||||
dim)
|
||||
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
||||
)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
@@ -115,33 +120,36 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if (int(height) != patch_pos_embed.shape[-2]
|
||||
or int(width) != patch_pos_embed.shape[-1]):
|
||||
raise ValueError("Width or height does not match with "
|
||||
"the interpolated position embeddings")
|
||||
if (
|
||||
int(height) != patch_pos_embed.shape[-2]
|
||||
or int(width) != patch_pos_embed.shape[-1]
|
||||
):
|
||||
raise ValueError(
|
||||
"Width or height does not match with "
|
||||
"the interpolated position embeddings"
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return patch_pos_embed
|
||||
|
||||
def forward(self,
|
||||
pixel_values: torch.Tensor,
|
||||
interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
||||
def forward(
|
||||
self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
|
||||
) -> torch.Tensor:
|
||||
_, _, height, width = pixel_values.shape
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(
|
||||
dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = self.patch_embedding(
|
||||
pixel_values.to(dtype=target_dtype)
|
||||
) # shape = [*, width, grid, grid]
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
if interpolate_pos_encoding:
|
||||
embeddings += self.interpolate_pos_encoding(
|
||||
embeddings, height, width)
|
||||
embeddings += self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings += self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
class SiglipAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
@@ -155,9 +163,11 @@ class SiglipAttention(nn.Module):
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(f"embed_dim must be divisible by num_heads (got "
|
||||
"`embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads}).")
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got "
|
||||
"`embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
@@ -179,8 +189,9 @@ class SiglipAttention(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
self.head_dim, self.scale)
|
||||
self.attn = MultiHeadAttention(
|
||||
self.num_heads_per_partition, self.head_dim, self.scale
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -197,7 +208,6 @@ class SiglipAttention(nn.Module):
|
||||
|
||||
|
||||
class SiglipMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
@@ -209,15 +219,14 @@ class SiglipMLP(nn.Module):
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
# Special handling for BNB and torchao quantization
|
||||
if quant_config and quant_config.get_name() in [
|
||||
"bitsandbytes", "torchao"
|
||||
]:
|
||||
if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
|
||||
quantizable = True
|
||||
else:
|
||||
# For other quantization, we require the hidden size to be a
|
||||
# multiple of 64
|
||||
quantizable = (config.hidden_size % 64 == 0
|
||||
and config.intermediate_size % 64 == 0)
|
||||
quantizable = (
|
||||
config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
|
||||
)
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
@@ -239,7 +248,6 @@ class SiglipMLP(nn.Module):
|
||||
|
||||
|
||||
class SiglipEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
@@ -255,15 +263,13 @@ class SiglipEncoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -284,7 +290,6 @@ class SiglipEncoderLayer(nn.Module):
|
||||
|
||||
|
||||
class SiglipEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
@@ -301,12 +306,16 @@ class SiglipEncoder(nn.Module):
|
||||
else:
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
SiglipEncoderLayer(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}")
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
])
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
SiglipEncoderLayer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
)
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -341,12 +350,12 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
||||
# TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
|
||||
self.attention = torch.nn.MultiheadAttention(
|
||||
config.hidden_size, config.num_attention_heads, batch_first=True)
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
config.hidden_size, config.num_attention_heads, batch_first=True
|
||||
)
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(
|
||||
config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
|
||||
)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = hidden_state.shape[0]
|
||||
@@ -363,7 +372,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
|
||||
|
||||
class SiglipVisionTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
@@ -399,13 +407,13 @@ class SiglipVisionTransformer(nn.Module):
|
||||
require_post_norm = len(self.encoder.layers) == num_hidden_layers
|
||||
|
||||
if require_post_norm:
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
else:
|
||||
self.post_layernorm = None
|
||||
|
||||
self.use_head = (True if not hasattr(config, "vision_use_head") else
|
||||
config.vision_use_head)
|
||||
self.use_head = (
|
||||
True if not hasattr(config, "vision_use_head") else config.vision_use_head
|
||||
)
|
||||
if self.use_head:
|
||||
self.head = SiglipMultiheadAttentionPoolingHead(
|
||||
config=config,
|
||||
@@ -493,8 +501,7 @@ class SiglipVisionModel(nn.Module):
|
||||
feature_select_strategy=feature_select_strategy,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -507,8 +514,10 @@ class SiglipVisionModel(nn.Module):
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# post_layernorm is optional in SiglipVisionModel
|
||||
if (name.startswith("vision_model.post_layernorm")
|
||||
and self.vision_model.post_layernorm is None):
|
||||
if (
|
||||
name.startswith("vision_model.post_layernorm")
|
||||
and self.vision_model.post_layernorm is None
|
||||
):
|
||||
continue
|
||||
|
||||
# omit layers when num_hidden_layers_override is set
|
||||
@@ -518,21 +527,21 @@ class SiglipVisionModel(nn.Module):
|
||||
continue
|
||||
|
||||
# Check if this is a scale parameter that needs remapping first
|
||||
if name.endswith(
|
||||
(".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
|
||||
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
|
||||
# Try to remap the scale name first
|
||||
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if remapped_name is not None and remapped_name in params_dict:
|
||||
# Successfully remapped, use the remapped name
|
||||
param = params_dict[remapped_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(remapped_name)
|
||||
continue
|
||||
# If remapping failed, continue with normal processing
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
@@ -543,8 +552,7 @@ class SiglipVisionModel(nn.Module):
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
Reference in New Issue
Block a user