Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
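
The hunks below are mechanical reformatting only; behaviour is unchanged. As an illustration of the recurring pattern in this diff (hypothetical function and variable names, not taken from the vLLM source), yapf's paren-aligned continuation lines and single quotes become ruff format's black-compatible style: 4-space hanging indents, one argument per line once a call is split, magic trailing commas, and double quotes.

# Illustrative sketch only (hypothetical names): the two formatting styles
# that appear throughout the diff below.


def describe(first_argument: str, second_argument: str, third_argument: str) -> str:
    return f"{first_argument}, {second_argument}, {third_argument}"


# Old style (yapf): continuation lines aligned to the opening parenthesis,
# closing parenthesis kept on the last argument line.
old_style = describe("queries", "keys",
                     "values")

# New style (ruff format): when a call does not fit on one line (or carries a
# magic trailing comma), each argument goes on its own 4-space-indented line,
# with a trailing comma and the closing parenthesis on its own line.
new_style = describe(
    "queries",
    "keys",
    "values",
)

assert old_style == new_style  # formatting only; the result is identical

The same convention explains the other changes below: single quotes become double quotes, slices gain spaces around the colon, and the blank line after a bare class statement is dropped.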
@@ -17,10 +17,13 @@ from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, QKVParallelLinear,
-                                               ReplicatedLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    LinearBase,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
@@ -28,23 +31,20 @@ from .vision import get_vit_attn_backend
 
 
 class VisionRotaryEmbedding(nn.Module):
-
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
-        inv_freq = 1.0 / (theta
-                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
     def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(seqlen,
-                           device=self.inv_freq.device,
-                           dtype=self.inv_freq.dtype)
+        seq = torch.arange(
+            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+        )
         freqs = torch.outer(seq, self.inv_freq)
         return freqs
 
 
 class Siglip2VisionEmbeddings(nn.Module):
-
     def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.config = config
@@ -58,15 +58,13 @@ class Siglip2VisionEmbeddings(nn.Module):
         # siglip2 naflex
         if self.num_patches > 0:
             self.patch_embedding = ReplicatedLinear(
-                input_size=config.num_channels * self.patch_size *
-                self.patch_size,
+                input_size=config.num_channels * self.patch_size * self.patch_size,
                 output_size=self.embed_dim,
                 return_bias=False,
             )
             if self.preserve_original_pe:
                 self.position_embedding_size = int(self.num_patches**0.5)
-                self.position_embedding = nn.Embedding(self.num_patches,
-                                                       self.embed_dim)
+                self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
 
         else:
             self.patch_embedding = nn.Conv2d(
@@ -77,15 +75,15 @@ class Siglip2VisionEmbeddings(nn.Module):
                 padding="valid",
             )
             if self.preserve_original_pe:
-                self.num_patches = (self.image_size // self.patch_size)**2
-                self.position_embedding_size = (self.image_size //
-                                                self.patch_size)
-                self.position_embedding = nn.Embedding(self.num_patches,
-                                                       self.embed_dim)
+                self.num_patches = (self.image_size // self.patch_size) ** 2
+                self.position_embedding_size = self.image_size // self.patch_size
+                self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
 
-    def forward(self,
-                pixel_values: torch.FloatTensor,
-                grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        grid_thws: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
         """
         Args:
             pixel_values (`torch.FloatTensor`):
@@ -100,36 +98,48 @@ class Siglip2VisionEmbeddings(nn.Module):
         # Apply patch embeddings to already patchified pixel values
         target_dtype = self.patch_embedding.weight.dtype
         if isinstance(self.patch_embedding, LinearBase):
-            patch_embeds = self.patch_embedding(
-                pixel_values.to(dtype=target_dtype))
+            patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
         elif isinstance(self.patch_embedding, nn.Conv2d):
             pixel_values = pixel_values.view(
-                -1, self.config.num_channels * self.config.temporal_patch_size,
-                self.patch_size, self.patch_size)
-            patch_embeds = self.patch_embedding(
-                pixel_values.to(dtype=target_dtype))
+                -1,
+                self.config.num_channels * self.config.temporal_patch_size,
+                self.patch_size,
+                self.patch_size,
+            )
+            patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
             patch_embeds = patch_embeds.reshape(-1, self.embed_dim)
 
         if self.preserve_original_pe:
             assert grid_thws is not None
             pos_embed_new = torch.zeros_like(patch_embeds)
-            positional_embeddings = self.position_embedding.weight.reshape(
-                self.position_embedding_size, self.position_embedding_size,
-                -1).unsqueeze(0).permute(0, 3, 1, 2)
+            positional_embeddings = (
+                self.position_embedding.weight.reshape(
+                    self.position_embedding_size, self.position_embedding_size, -1
+                )
+                .unsqueeze(0)
+                .permute(0, 3, 1, 2)
+            )
             cnt = 0
             for t, h, w in grid_thws:
                 volume = t * h * w
-                pe = F.interpolate(positional_embeddings,
-                                   size=(h, w),
-                                   mode='bicubic',
-                                   align_corners=False)
+                pe = F.interpolate(
+                    positional_embeddings,
+                    size=(h, w),
+                    mode="bicubic",
+                    align_corners=False,
+                )
                 pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
                 pe = pe[0].repeat(t, 1)
-                pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride,
-                                w // self.hidden_stride, self.hidden_stride,
-                                -1)
+                pe = pe.reshape(
+                    t,
+                    h // self.hidden_stride,
+                    self.hidden_stride,
+                    w // self.hidden_stride,
+                    self.hidden_stride,
+                    -1,
+                )
                 pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1)
-                pos_embed_new[cnt:cnt + volume] = pe
+                pos_embed_new[cnt : cnt + volume] = pe
                 cnt += volume
             patch_embeds = patch_embeds + pos_embed_new
 
@@ -143,9 +153,9 @@ def rotate_half(x, interleaved=False):
         return torch.cat((-x2, x1), dim=-1)
     else:
         x1, x2 = x[..., ::2], x[..., 1::2]
-        return rearrange(torch.stack((-x2, x1), dim=-1),
-                         "... d two -> ... (d two)",
-                         two=2)
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
 
 
 def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
@@ -156,15 +166,15 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
     ro_dim = cos.shape[-1] * 2
     assert ro_dim <= x.shape[-1]
     cos = repeat(
-        cos,
-        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
     sin = repeat(
-        sin,
-        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
     return torch.cat(
         [
-            x[..., :ro_dim] * cos +
-            rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
         ],
         dim=-1,
     )
@@ -181,13 +191,12 @@ def apply_rotary_pos_emb(
     sin = sin.chunk(2, dim=-1)[0].contiguous()
    if is_flash_attn_backend:
         from flash_attn.layers.rotary import apply_rotary_emb
+
         apply_rotary_emb_func = apply_rotary_emb
     else:
         apply_rotary_emb_func = apply_rotary_emb_torch
-    q_embed = apply_rotary_emb_func(q.float(), cos.float(),
-                                    sin.float()).type_as(q)
-    k_embed = apply_rotary_emb_func(k.float(), cos.float(),
-                                    sin.float()).type_as(k)
+    q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
+    k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
     return q_embed, k_embed
 
 
@@ -210,7 +219,8 @@ class Siglip2Attention(nn.Module):
             raise ValueError(
                 f"embed_dim must be divisible by num_heads "
                 f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {self.num_heads}).")
+                f" {self.num_heads})."
+            )
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
         self.is_causal = False
@@ -231,37 +241,41 @@ class Siglip2Attention(nn.Module):
             prefix=f"{prefix}.out_proj",
         )
 
-        self.tp_size = (1 if use_data_parallel else
-                        get_tensor_model_parallel_world_size())
+        self.tp_size = (
+            1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        )
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
         self.use_rope = config.use_rope
 
         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
-            head_size=self.head_dim, dtype=torch.get_default_dtype())
+            head_size=self.head_dim, dtype=torch.get_default_dtype()
+        )
         self.use_upstream_fa = False
 
-        self.attn_backend, self.flash_attn_varlen_func \
-            = maybe_get_vit_flash_attn_backend(
+        self.attn_backend, self.flash_attn_varlen_func = (
+            maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
             )
+        )
 
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
-                _Backend.ROCM_AITER_FA
+            _Backend.FLASH_ATTN,
+            _Backend.TORCH_SDPA,
+            _Backend.ROCM_AITER_FA,
         }:
             self.attn_backend = _Backend.TORCH_SDPA
         self.is_flash_attn_backend = self.attn_backend in {
-            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+            _Backend.FLASH_ATTN,
+            _Backend.ROCM_AITER_FA,
         }
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        position_embeddings: Optional[tuple[torch.Tensor,
-                                            torch.Tensor]] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -270,26 +284,27 @@ class Siglip2Attention(nn.Module):
         qkv_states, _ = self.qkv_proj(hidden_states)
         queries, keys, values = qkv_states.chunk(3, dim=-1)
 
-        queries = queries.view(seq_length, self.num_heads_per_partition,
-                               self.head_dim)
-        keys = keys.view(seq_length, self.num_heads_per_partition,
-                         self.head_dim)
-        values = values.view(seq_length, self.num_heads_per_partition,
-                             self.head_dim)
+        queries = queries.view(seq_length, self.num_heads_per_partition, self.head_dim)
+        keys = keys.view(seq_length, self.num_heads_per_partition, self.head_dim)
+        values = values.view(seq_length, self.num_heads_per_partition, self.head_dim)
 
         if self.use_rope:
             cos, sin = position_embeddings
-            queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0),
-                                                 keys.unsqueeze(0), cos, sin,
-                                                 self.is_flash_attn_backend)
+            queries, keys = apply_rotary_pos_emb(
+                queries.unsqueeze(0),
+                keys.unsqueeze(0),
+                cos,
+                sin,
+                self.is_flash_attn_backend,
+            )
             queries = queries.squeeze(0)
             keys = keys.squeeze(0)
 
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         if self.is_flash_attn_backend:
             attn_output = self.flash_attn_varlen_func(
-                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
-                max_seqlen).reshape(seq_length, -1)
+                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
+            ).reshape(seq_length, -1)
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             batch_size = cu_seqlens.shape[0] - 1
@@ -308,13 +323,9 @@ class Siglip2Attention(nn.Module):
                 # (1, num_heads, seq_len, head_dim)
                 q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]
 
-                output_i = F.scaled_dot_product_attention(q_i,
-                                                          k_i,
-                                                          v_i,
-                                                          dropout_p=0.0)
+                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                 # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
-                output_i = output_i.transpose(1, 2).reshape(
-                    end_idx - start_idx, -1)
+                output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
                 outputs.append(output_i)
 
             attn_output = torch.cat(outputs, dim=0)
@@ -323,7 +334,6 @@ class Siglip2Attention(nn.Module):
 
 
 class Siglip2MLP(nn.Module):
-
     def __init__(
         self,
         config: Siglip2VisionConfig,
@@ -357,7 +367,6 @@ class Siglip2MLP(nn.Module):
 
 
 class Siglip2EncoderLayer(nn.Module):
-
     def __init__(
         self,
        config: Siglip2VisionConfig,
@@ -367,21 +376,27 @@ class Siglip2EncoderLayer(nn.Module):
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
-                                        eps=config.layer_norm_eps)
-        self.self_attn = Siglip2Attention(config,
-                                          quant_config=quant_config,
-                                          prefix=f"{prefix}.self_attn",
-                                          use_data_parallel=use_data_parallel)
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
-                                        eps=config.layer_norm_eps)
-        self.mlp = Siglip2MLP(config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.mlp",
-                              use_data_parallel=use_data_parallel)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.self_attn = Siglip2Attention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=use_data_parallel,
+        )
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = Siglip2MLP(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
+        )
 
-    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
-                position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: torch.Tensor,
+    ) -> tuple[torch.FloatTensor]:
         """
         Args:
             hidden_states: Input tensor of shape (batch, seq_len, embed_dim).
@@ -391,9 +406,11 @@ class Siglip2EncoderLayer(nn.Module):
         residual = hidden_states
 
         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states,
-                                       cu_seqlens=cu_seqlens,
-                                       position_embeddings=position_embeddings)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
+        )
         hidden_states = residual + hidden_states
 
         residual = hidden_states
@@ -405,7 +422,7 @@ class Siglip2EncoderLayer(nn.Module):
 
 class Siglip2Encoder(nn.Module):
     """
-    Transformer encoder consisting of `config.num_hidden_layers` 
+    Transformer encoder consisting of `config.num_hidden_layers`
     self attention layers. Each layer is a [`Siglip2EncoderLayer`].
 
     Args:
@@ -421,16 +438,21 @@ class Siglip2Encoder(nn.Module):
     ):
         super().__init__()
         self.config = config
-        self.layers = nn.ModuleList([
-            Siglip2EncoderLayer(config,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.layers.{idx}",
-                                use_data_parallel=use_data_parallel)
-            for idx in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                Siglip2EncoderLayer(
+                    config,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{idx}",
+                    use_data_parallel=use_data_parallel,
+                )
+                for idx in range(config.num_hidden_layers)
+            ]
+        )
 
         self.rotary_pos_emb = VisionRotaryEmbedding(
-            config.hidden_size // config.num_attention_heads // 2)
+            config.hidden_size // config.num_attention_heads // 2
+        )
         self.patch_size = config.patch_size
         self.hidden_stride = config.hidden_stride
         self.window_size = config.window_size
@@ -439,7 +461,7 @@ class Siglip2Encoder(nn.Module):
             self.fullatt_block_indexes = None
         else:
             self.fullatt_block_indexes = [
-                int(i) for i in config.fullatt_block_indexes.split('|')
+                int(i) for i in config.fullatt_block_indexes.split("|")
             ]
 
     # copied from qwen2.5_vl
@@ -465,8 +487,7 @@ class Siglip2Encoder(nn.Module):
             )
             wpos_ids = wpos_ids.permute(0, 2, 1, 3)
             wpos_ids = wpos_ids.flatten()
-            pos_ids.append(
-                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
@@ -478,8 +499,9 @@ class Siglip2Encoder(nn.Module):
         cu_window_seqlens: list = [0]
         window_index_id = 0
         # patch (after merge) number in each window
-        vit_merger_window_size = (self.window_size // self.hidden_stride //
-                                  self.patch_size)
+        vit_merger_window_size = (
+            self.window_size // self.hidden_stride // self.patch_size
+        )
 
         for grid_t, grid_h, grid_w in grid_thw:
             llm_grid_h, llm_grid_w = (
@@ -487,7 +509,8 @@ class Siglip2Encoder(nn.Module):
                 grid_w // self.hidden_stride,
             )
             index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
-                grid_t, llm_grid_h, llm_grid_w)
+                grid_t, llm_grid_h, llm_grid_w
+            )
             pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
             pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
             num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
@@ -510,8 +533,9 @@ class Siglip2Encoder(nn.Module):
             index_padded = index_padded.reshape(-1)
             index_new = index_padded[index_padded != -100]
             window_index.append(index_new + window_index_id)
-            cu_seqlens_tmp = seqlens.cumsum(
-                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_seqlens_tmp = (
+                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            )
             cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
             window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
         window_index = torch.cat(window_index, dim=0)
@@ -525,10 +549,10 @@ class Siglip2Encoder(nn.Module):
     ) -> torch.Tensor:
         r"""
         Args:
-            inputs_embeds: Input tensor of shape 
+            inputs_embeds: Input tensor of shape
                 (batch_size, sequence_length, hidden_size).
                 Embedded representation of the input tokens.
-            grid_thws: Grid tensor of shape (num_patches, 3) 
+            grid_thws: Grid tensor of shape (num_patches, 3)
                 containing grid dimensions.
            Whether or not to return a [`~utils.ModelOutput`] instead of
                 a plain tuple.
@@ -544,11 +568,13 @@ class Siglip2Encoder(nn.Module):
 
         seq_len, _ = inputs_embeds.size()
         inputs_embeds = inputs_embeds.reshape(
-            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+        )
         inputs_embeds = inputs_embeds[window_index, :, :]
         inputs_embeds = inputs_embeds.reshape(seq_len, -1)
         rotary_pos_emb = rotary_pos_emb.reshape(
-            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+        )
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
@@ -572,23 +598,21 @@ class Siglip2Encoder(nn.Module):
 
         hidden_states = inputs_embeds
         for index, block in enumerate(self.layers):
-            if (not self.fullatt_block_indexes
-                    or index in self.fullatt_block_indexes):
+            if not self.fullatt_block_indexes or index in self.fullatt_block_indexes:
                 cu_seqlens_tmp = cu_seqlens
             else:
                 cu_seqlens_tmp = cu_window_seqlens
-            hidden_states = block(hidden_states, cu_seqlens_tmp,
-                                  position_embeddings)
+            hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings)
 
         hidden_states = hidden_states.reshape(
-            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+        )
         hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
 
         return hidden_states
 
 
 class Siglip2VisionTransformer(nn.Module):
-
     def __init__(
         self,
         config: Siglip2VisionConfig,
@@ -601,12 +625,13 @@ class Siglip2VisionTransformer(nn.Module):
         embed_dim = config.hidden_size
 
         self.embeddings = Siglip2VisionEmbeddings(config)
-        self.encoder = Siglip2Encoder(config,
-                                      quant_config=quant_config,
-                                      prefix=f"{prefix}.encoder",
-                                      use_data_parallel=use_data_parallel)
-        self.post_layernorm = nn.LayerNorm(embed_dim,
-                                           eps=config.layer_norm_eps)
+        self.encoder = Siglip2Encoder(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.encoder",
+            use_data_parallel=use_data_parallel,
+        )
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -627,7 +652,6 @@ class Siglip2VisionTransformer(nn.Module):
 
 
 class Siglip2NavitModel(torch.nn.Module):
-
     def __init__(
         self,
         config: Siglip2VisionConfig,
@@ -641,7 +665,8 @@ class Siglip2NavitModel(torch.nn.Module):
             config,
             quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+        )
 
     def forward(
         self,
@@ -653,8 +678,7 @@ class Siglip2NavitModel(torch.nn.Module):
             grid_thws=grid_thws,
         )
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -665,7 +689,7 @@ class Siglip2NavitModel(torch.nn.Module):
         loaded_params: set[str] = set()
 
         for name, loaded_weight in weights:
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -676,8 +700,7 @@ class Siglip2NavitModel(torch.nn.Module):
                 break
             else:
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params