Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -17,10 +17,13 @@ from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -28,23 +31,20 @@ from .vision import get_vit_attn_backend
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
inv_freq = 1.0 / (theta
**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(seqlen,
device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
seq = torch.arange(
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class Siglip2VisionEmbeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
@@ -58,15 +58,13 @@ class Siglip2VisionEmbeddings(nn.Module):
# siglip2 naflex
if self.num_patches > 0:
self.patch_embedding = ReplicatedLinear(
input_size=config.num_channels * self.patch_size *
self.patch_size,
input_size=config.num_channels * self.patch_size * self.patch_size,
output_size=self.embed_dim,
return_bias=False,
)
if self.preserve_original_pe:
self.position_embedding_size = int(self.num_patches**0.5)
self.position_embedding = nn.Embedding(self.num_patches,
self.embed_dim)
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
else:
self.patch_embedding = nn.Conv2d(
@@ -77,15 +75,15 @@ class Siglip2VisionEmbeddings(nn.Module):
padding="valid",
)
if self.preserve_original_pe:
self.num_patches = (self.image_size // self.patch_size)**2
self.position_embedding_size = (self.image_size //
self.patch_size)
self.position_embedding = nn.Embedding(self.num_patches,
self.embed_dim)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.position_embedding_size = self.image_size // self.patch_size
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
def forward(self,
pixel_values: torch.FloatTensor,
grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
def forward(
self,
pixel_values: torch.FloatTensor,
grid_thws: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
"""
Args:
pixel_values (`torch.FloatTensor`):
@@ -100,36 +98,48 @@ class Siglip2VisionEmbeddings(nn.Module):
# Apply patch embeddings to already patchified pixel values
target_dtype = self.patch_embedding.weight.dtype
if isinstance(self.patch_embedding, LinearBase):
patch_embeds = self.patch_embedding(
pixel_values.to(dtype=target_dtype))
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
elif isinstance(self.patch_embedding, nn.Conv2d):
pixel_values = pixel_values.view(
-1, self.config.num_channels * self.config.temporal_patch_size,
self.patch_size, self.patch_size)
patch_embeds = self.patch_embedding(
pixel_values.to(dtype=target_dtype))
-1,
self.config.num_channels * self.config.temporal_patch_size,
self.patch_size,
self.patch_size,
)
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
patch_embeds = patch_embeds.reshape(-1, self.embed_dim)
if self.preserve_original_pe:
assert grid_thws is not None
pos_embed_new = torch.zeros_like(patch_embeds)
positional_embeddings = self.position_embedding.weight.reshape(
self.position_embedding_size, self.position_embedding_size,
-1).unsqueeze(0).permute(0, 3, 1, 2)
positional_embeddings = (
self.position_embedding.weight.reshape(
self.position_embedding_size, self.position_embedding_size, -1
)
.unsqueeze(0)
.permute(0, 3, 1, 2)
)
cnt = 0
for t, h, w in grid_thws:
volume = t * h * w
pe = F.interpolate(positional_embeddings,
size=(h, w),
mode='bicubic',
align_corners=False)
pe = F.interpolate(
positional_embeddings,
size=(h, w),
mode="bicubic",
align_corners=False,
)
pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
pe = pe[0].repeat(t, 1)
pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride,
w // self.hidden_stride, self.hidden_stride,
-1)
pe = pe.reshape(
t,
h // self.hidden_stride,
self.hidden_stride,
w // self.hidden_stride,
self.hidden_stride,
-1,
)
pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1)
pos_embed_new[cnt:cnt + volume] = pe
pos_embed_new[cnt : cnt + volume] = pe
cnt += volume
patch_embeds = patch_embeds + pos_embed_new
@@ -143,9 +153,9 @@ def rotate_half(x, interleaved=False):
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1),
"... d two -> ... (d two)",
two=2)
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
@@ -156,15 +166,15 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos,
"... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin,
"... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos +
rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
@@ -181,13 +191,12 @@ def apply_rotary_pos_emb(
sin = sin.chunk(2, dim=-1)[0].contiguous()
if is_flash_attn_backend:
from flash_attn.layers.rotary import apply_rotary_emb
apply_rotary_emb_func = apply_rotary_emb
else:
apply_rotary_emb_func = apply_rotary_emb_torch
q_embed = apply_rotary_emb_func(q.float(), cos.float(),
sin.float()).type_as(q)
k_embed = apply_rotary_emb_func(k.float(), cos.float(),
sin.float()).type_as(k)
q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed
@@ -210,7 +219,8 @@ class Siglip2Attention(nn.Module):
raise ValueError(
f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
@@ -231,37 +241,41 @@ class Siglip2Attention(nn.Module):
prefix=f"{prefix}.out_proj",
)
self.tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.use_rope = config.use_rope
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
head_size=self.head_dim, dtype=torch.get_default_dtype())
head_size=self.head_dim, dtype=torch.get_default_dtype()
)
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)
)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
_Backend.ROCM_AITER_FA
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.ROCM_AITER_FA,
}:
self.attn_backend = _Backend.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
}
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -270,26 +284,27 @@ class Siglip2Attention(nn.Module):
qkv_states, _ = self.qkv_proj(hidden_states)
queries, keys, values = qkv_states.chunk(3, dim=-1)
queries = queries.view(seq_length, self.num_heads_per_partition,
self.head_dim)
keys = keys.view(seq_length, self.num_heads_per_partition,
self.head_dim)
values = values.view(seq_length, self.num_heads_per_partition,
self.head_dim)
queries = queries.view(seq_length, self.num_heads_per_partition, self.head_dim)
keys = keys.view(seq_length, self.num_heads_per_partition, self.head_dim)
values = values.view(seq_length, self.num_heads_per_partition, self.head_dim)
if self.use_rope:
cos, sin = position_embeddings
queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0),
keys.unsqueeze(0), cos, sin,
self.is_flash_attn_backend)
queries, keys = apply_rotary_pos_emb(
queries.unsqueeze(0),
keys.unsqueeze(0),
cos,
sin,
self.is_flash_attn_backend,
)
queries = queries.squeeze(0)
keys = keys.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
if self.is_flash_attn_backend:
attn_output = self.flash_attn_varlen_func(
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
max_seqlen).reshape(seq_length, -1)
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
).reshape(seq_length, -1)
elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
batch_size = cu_seqlens.shape[0] - 1
@@ -308,13 +323,9 @@ class Siglip2Attention(nn.Module):
# (1, num_heads, seq_len, head_dim)
q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]
output_i = F.scaled_dot_product_attention(q_i,
k_i,
v_i,
dropout_p=0.0)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
output_i = output_i.transpose(1, 2).reshape(
end_idx - start_idx, -1)
output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
outputs.append(output_i)
attn_output = torch.cat(outputs, dim=0)
@@ -323,7 +334,6 @@ class Siglip2Attention(nn.Module):
class Siglip2MLP(nn.Module):
def __init__(
self,
config: Siglip2VisionConfig,
@@ -357,7 +367,6 @@ class Siglip2MLP(nn.Module):
class Siglip2EncoderLayer(nn.Module):
def __init__(
self,
config: Siglip2VisionConfig,
@@ -367,21 +376,27 @@ class Siglip2EncoderLayer(nn.Module):
):
super().__init__()
self.embed_dim = config.hidden_size
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.self_attn = Siglip2Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.self_attn = Siglip2Attention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(
config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> tuple[torch.FloatTensor]:
"""
Args:
hidden_states: Input tensor of shape (batch, seq_len, embed_dim).
@@ -391,9 +406,11 @@ class Siglip2EncoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings)
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
residual = hidden_states
@@ -405,7 +422,7 @@ class Siglip2EncoderLayer(nn.Module):
class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers`
Transformer encoder consisting of `config.num_hidden_layers`
self attention layers. Each layer is a [`Siglip2EncoderLayer`].
Args:
@@ -421,16 +438,21 @@ class Siglip2Encoder(nn.Module):
):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Siglip2EncoderLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{idx}",
use_data_parallel=use_data_parallel)
for idx in range(config.num_hidden_layers)
])
self.layers = nn.ModuleList(
[
Siglip2EncoderLayer(
config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{idx}",
use_data_parallel=use_data_parallel,
)
for idx in range(config.num_hidden_layers)
]
)
self.rotary_pos_emb = VisionRotaryEmbedding(
config.hidden_size // config.num_attention_heads // 2)
config.hidden_size // config.num_attention_heads // 2
)
self.patch_size = config.patch_size
self.hidden_stride = config.hidden_stride
self.window_size = config.window_size
@@ -439,7 +461,7 @@ class Siglip2Encoder(nn.Module):
self.fullatt_block_indexes = None
else:
self.fullatt_block_indexes = [
int(i) for i in config.fullatt_block_indexes.split('|')
int(i) for i in config.fullatt_block_indexes.split("|")
]
# copied from qwen2.5_vl
@@ -465,8 +487,7 @@ class Siglip2Encoder(nn.Module):
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
@@ -478,8 +499,9 @@ class Siglip2Encoder(nn.Module):
cu_window_seqlens: list = [0]
window_index_id = 0
# patch (after merge) number in each window
vit_merger_window_size = (self.window_size // self.hidden_stride //
self.patch_size)
vit_merger_window_size = (
self.window_size // self.hidden_stride // self.patch_size
)
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (
@@ -487,7 +509,8 @@ class Siglip2Encoder(nn.Module):
grid_w // self.hidden_stride,
)
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
grid_t, llm_grid_h, llm_grid_w)
grid_t, llm_grid_h, llm_grid_w
)
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
@@ -510,8 +533,9 @@ class Siglip2Encoder(nn.Module):
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = seqlens.cumsum(
0) * self.spatial_merge_unit + cu_window_seqlens[-1]
cu_seqlens_tmp = (
seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
)
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
@@ -525,10 +549,10 @@ class Siglip2Encoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
inputs_embeds: Input tensor of shape
inputs_embeds: Input tensor of shape
(batch_size, sequence_length, hidden_size).
Embedded representation of the input tokens.
grid_thws: Grid tensor of shape (num_patches, 3)
grid_thws: Grid tensor of shape (num_patches, 3)
containing grid dimensions.
Whether or not to return a [`~utils.ModelOutput`] instead of
a plain tuple.
@@ -544,11 +568,13 @@ class Siglip2Encoder(nn.Module):
seq_len, _ = inputs_embeds.size()
inputs_embeds = inputs_embeds.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
)
inputs_embeds = inputs_embeds[window_index, :, :]
inputs_embeds = inputs_embeds.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
@@ -572,23 +598,21 @@ class Siglip2Encoder(nn.Module):
hidden_states = inputs_embeds
for index, block in enumerate(self.layers):
if (not self.fullatt_block_indexes
or index in self.fullatt_block_indexes):
if not self.fullatt_block_indexes or index in self.fullatt_block_indexes:
cu_seqlens_tmp = cu_seqlens
else:
cu_seqlens_tmp = cu_window_seqlens
hidden_states = block(hidden_states, cu_seqlens_tmp,
position_embeddings)
hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings)
hidden_states = hidden_states.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
)
hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
return hidden_states
class Siglip2VisionTransformer(nn.Module):
def __init__(
self,
config: Siglip2VisionConfig,
@@ -601,12 +625,13 @@ class Siglip2VisionTransformer(nn.Module):
embed_dim = config.hidden_size
self.embeddings = Siglip2VisionEmbeddings(config)
self.encoder = Siglip2Encoder(config,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel)
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
self.encoder = Siglip2Encoder(
config,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel,
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def forward(
self,
@@ -627,7 +652,6 @@ class Siglip2VisionTransformer(nn.Module):
class Siglip2NavitModel(torch.nn.Module):
def __init__(
self,
config: Siglip2VisionConfig,
@@ -641,7 +665,8 @@ class Siglip2NavitModel(torch.nn.Module):
config,
quant_config=quant_config,
prefix=f"{prefix}.vision_model",
use_data_parallel=use_data_parallel)
use_data_parallel=use_data_parallel,
)
def forward(
self,
@@ -653,8 +678,7 @@ class Siglip2NavitModel(torch.nn.Module):
grid_thws=grid_thws,
)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -665,7 +689,7 @@ class Siglip2NavitModel(torch.nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -676,8 +700,7 @@ class Siglip2NavitModel(torch.nn.Module):
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params