support dynamic resolution image encoding for Nemotron Nano VL (#32121)
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
@@ -21,7 +21,11 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.intern_vit import InternVisionEncoder
|
||||
from vllm.model_executor.models.intern_vit import (
|
||||
InternParallelAttention,
|
||||
InternVisionEncoder,
|
||||
InternVisionEncoderLayer,
|
||||
)
|
||||
|
||||
input_dim_t: TypeAlias = int | tuple[int, int]
|
||||
norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor
|
||||
@@ -43,6 +47,15 @@ to_4tuple = _ntuple(4)
|
||||
to_ntuple = _ntuple
|
||||
|
||||
|
||||
def calc_seq_len(size: tuple[int, int], patch_size: int) -> int:
    """Return how many whole patches tile an image of the given size.

    Args:
        size: Two-element ``(h, w)`` pair.
        patch_size: Side length of one square patch.

    Returns:
        ``(h // patch_size) * (w // patch_size)``; any remainder on either
        side is discarded by the floor division.
    """
    rows, cols = (side // patch_size for side in size)
    return rows * cols
|
||||
|
||||
|
||||
def calc_seq_lens(sizes: list[tuple[int, int]], patch_size: int) -> list[int]:
    """Return the patch count of every image, in input order.

    Args:
        sizes: One ``(h, w)`` pair per image.
        patch_size: Side length of one square patch.

    Returns:
        A list with ``calc_seq_len(size, patch_size)`` for each entry of
        ``sizes``; empty when ``sizes`` is empty.
    """
    return [calc_seq_len(image_size, patch_size) for image_size in sizes]
|
||||
|
||||
|
||||
class ClsToken(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -164,15 +177,73 @@ class ViTPatchGenerator(nn.Module):
|
||||
nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(
    self, x: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Embed the input, add positional encodings and class tokens, normalize.

    Args:
        x: Input tensor. On the fixed-resolution path its trailing
            dimensions (``x.shape[2:]``) are used as the image size.
        imgs_sizes: Optional per-image ``(h, w)`` sizes. When given, the
            dynamic-resolution path is taken and positional encodings and
            class tokens are applied image by image.

    Returns:
        The normalized patch sequence, or ``(patches, pos_enc)`` when
        ``self.return_pos_enc`` is set. ``pos_enc`` may be ``None`` on the
        dynamic path (absolute position encoding disabled or no images).
    """
    if imgs_sizes is not None:
        # Dynamic resolution: ``x`` goes straight to the embedder, skipping
        # ``embed_patches``.
        # NOTE(review): presumably ``x`` is already a packed patch sequence
        # here - confirm against the caller.
        patches = self.embedder(x)
        patches, pos_enc = self.apply_pos_enc_dynamic(
            patches, imgs_sizes=imgs_sizes
        )
        patches = self.cls_token_dynamic(patches, imgs_sizes=imgs_sizes)
    else:
        patches = self.embed_patches(x)
        patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
        patches = self.cls_token(patches)

    patches = self.patch_normalizer(patches)
    if self.return_pos_enc:
        return patches, pos_enc
    return patches
|
||||
|
||||
def apply_pos_enc_dynamic(
    self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Add a size-specific positional encoding to each image's patch span.

    ``patches`` is indexed as ``patches[:, start:end, :]``: the images are
    laid out back to back along dimension 1, and image ``i`` occupies
    ``calc_seq_len(imgs_sizes[i], self.patch_size)`` consecutive positions.

    Args:
        patches: Packed patch embeddings for all images.
        imgs_sizes: One ``(h, w)`` pair per image, in packing order.

    Returns:
        ``(patches_with_pos, full_pos_enc)``. ``full_pos_enc`` is the
        per-image encodings concatenated along dimension 1. When absolute
        position encoding is disabled, or ``imgs_sizes`` is empty, the input
        is returned unchanged together with ``None``.
    """
    if not self.abs_pos:
        return patches, None

    batch_size = patches.shape[0]
    segments = []
    pos_enc_list = []
    current_length = 0

    for size in imgs_sizes:
        seq_length = calc_seq_len(size, self.patch_size)
        pos_enc = self.get_pos_enc(batch_size, input_size=size)
        img_patches = patches[:, current_length : current_length + seq_length, :]
        segments.append(img_patches + pos_enc)
        pos_enc_list.append(pos_enc)
        current_length += seq_length

    if not pos_enc_list:
        return patches, None

    # Keep any positions past the last image untouched.
    segments.append(patches[:, current_length:, :])

    # One concatenation at the end instead of rebuilding the whole packed
    # tensor once per image.
    return torch.cat(segments, dim=1), torch.cat(pos_enc_list, dim=1)
|
||||
|
||||
def cls_token_dynamic(
    self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
    """Prepend the class-token block to each image's span of patches.

    Args:
        patches: Packed patch embeddings, images laid out back to back along
            dimension 1.
        imgs_sizes: One ``(h, w)`` pair per image, in packing order.

    Returns:
        ``patches`` unchanged when the class token is disabled. Otherwise a
        tensor in which every image's patches are preceded by the class-token
        block. Positions past the last image are not carried over.
    """
    if not self.cls_token.enabled:
        return patches

    # The same token block is used for every image, so build the
    # batch-expanded view once rather than on each loop iteration.
    class_token = self.cls_token.token.unsqueeze(0).expand(
        patches.shape[0], -1, -1
    )

    out = []
    current_length = 0
    for seq_len in calc_seq_lens(imgs_sizes, self.patch_size):
        next_length = current_length + seq_len
        out.append(class_token)
        out.append(patches[:, current_length:next_length, :])
        current_length = next_length

    return torch.cat(out, dim=1)
|
||||
|
||||
@property
def apply_cls_token(self):
    """Expose the ``enabled`` flag of the class-token module."""
    cls_token_module = self.cls_token
    return cls_token_module.enabled
|
||||
@@ -406,6 +477,66 @@ class ViTPatchLinear(nn.Linear):
|
||||
self.patch_size = patch_size
|
||||
|
||||
|
||||
class RadioParallelAttention(InternParallelAttention):
    """Attention layer that optionally accepts an explicit attention mask.

    Without a mask the call is delegated unchanged to the parent class. With
    a mask, attention is computed through
    ``F.scaled_dot_product_attention`` so the mask can be applied.
    """

    def forward(
        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run attention over ``x``, restricted by ``attn_mask`` if given.

        Args:
            x: Three-dimensional input; unpacked as ``(B, N, _)``.
            attn_mask: Optional mask forwarded as-is to
                ``F.scaled_dot_product_attention``.

        Returns:
            The first element of the output projection's result.
        """
        if attn_mask is None:
            return super().forward(x)

        batch, seq, _ = x.shape
        fused, _ = self.qkv(x)
        query, key, value = fused.chunk(3, dim=-1)

        if self.qk_normalization:
            query, key = self._apply_qk_norm(query, key)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            per_head = t.view(
                batch, seq, self.num_heads_per_partition, self.head_dim
            )
            return per_head.transpose(1, 2)

        attended = F.scaled_dot_product_attention(
            split_heads(query),
            split_heads(key),
            split_heads(value),
            attn_mask=attn_mask,
            scale=self.scale,
        )
        # Back to (batch, seq, heads * head_dim) for the output projection.
        merged = attended.transpose(1, 2).reshape(batch, seq, -1)
        projected, _ = self.proj(merged)
        return projected
|
||||
|
||||
|
||||
class RadioVisionEncoderLayer(InternVisionEncoderLayer):
    """Encoder layer built with ``RadioParallelAttention`` as its attention."""

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to the parent, fixing ``attn_cls``."""
        super().__init__(*args, attn_cls=RadioParallelAttention, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Apply the attention and MLP residual branches in sequence.

        Args:
            hidden_states: Layer input.
            attn_mask: Optional mask handed to the attention module.

        Returns:
            The input plus the ``ls1``-scaled attention branch, plus the
            ``ls2``-scaled MLP branch computed on that intermediate sum.
        """
        attn_branch = self.attn(self.norm1(hidden_states), attn_mask=attn_mask)
        hidden_states = hidden_states + attn_branch * self.ls1

        mlp_branch = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_branch * self.ls2
|
||||
|
||||
|
||||
class RadioVisionEncoder(InternVisionEncoder):
    """Encoder stack whose layers are ``RadioVisionEncoderLayer`` instances."""

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to the parent, fixing ``layer_cls``."""
        super().__init__(*args, layer_cls=RadioVisionEncoderLayer, **kwargs)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Pass the embeddings through every layer in ``self.layers``.

        Args:
            inputs_embeds: Input to the first layer.
            attn_mask: Optional mask given, unchanged, to every layer.

        Returns:
            The output of the last layer (the input itself if there are no
            layers).
        """
        state = inputs_embeds
        for layer in self.layers:
            state = layer(state, attn_mask=attn_mask)
        return state
|
||||
|
||||
|
||||
class RadioInternVisionModel(nn.Module):
|
||||
packed_modules_mapping = {
|
||||
"qkv": ["qkv"],
|
||||
@@ -440,7 +571,7 @@ class RadioInternVisionModel(nn.Module):
|
||||
register_multiple=config.register_multiple,
|
||||
)
|
||||
|
||||
self.encoder = InternVisionEncoder(
|
||||
self.encoder = RadioVisionEncoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
@@ -459,10 +590,45 @@ class RadioInternVisionModel(nn.Module):
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.FloatTensor:
|
||||
def create_inter_image_attention_mask(
    self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> torch.Tensor:
    """Build a block-diagonal mask that keeps attention within each image.

    Each image contributes its patch count plus
    ``self.patch_generator.num_skip`` extra positions to the packed sequence.

    Args:
        imgs_sizes: One ``(h, w)`` pair per image, in packing order.
        device: Device on which the mask is allocated.

    Returns:
        A square boolean tensor over all packed positions. An entry is
        ``True`` only when its row and column fall inside the same image's
        span; all other entries are ``False`` (masked out).
    """
    generator = self.patch_generator
    patch_size = generator.patch_size
    num_skip = generator.num_skip

    tokens_per_image = [
        num_patches + num_skip
        for num_patches in calc_seq_lens(imgs_sizes, patch_size)
    ]
    total_tokens = sum(tokens_per_image)

    # Start fully masked, then open one square block per image.
    mask = torch.zeros(
        (total_tokens, total_tokens), dtype=torch.bool, device=device
    )
    offset = 0
    for token_count in tokens_per_image:
        span = slice(offset, offset + token_count)
        mask[span, span] = True
        offset += token_count

    return mask
|
||||
|
||||
def forward(
    self,
    x: torch.Tensor,
    imgs_sizes: list[tuple[int, int]] | None = None,
) -> torch.FloatTensor:
    """Embed the input and run it through the encoder.

    Args:
        x: Input tensor handed to the patch generator.
        imgs_sizes: Optional per-image ``(h, w)`` sizes selecting the
            dynamic-resolution path of the patch generator.

    Returns:
        The encoder output.
    """
    assert self.patch_generator is not None
    hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)

    attn_mask = None
    if imgs_sizes is not None and len(imgs_sizes) > 1:
        # Dynamic resolution with several images packed into one sequence:
        # restrict attention so images do not attend to each other. With a
        # single image no mask is needed.
        attn_mask = self.create_inter_image_attention_mask(
            imgs_sizes, device=x.device
        )

    encoder_outputs = self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask)
    return encoder_outputs
|
||||
|
||||
|
||||
@@ -504,9 +670,11 @@ class RadioModel(nn.Module):
|
||||
self,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
pixel_embeds: torch.Tensor | None = None,
|
||||
*,
|
||||
imgs_sizes: torch.Tensor | None = None,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
y = self.model(pixel_values)
|
||||
return self._extract_final(y)
|
||||
y = self.model(pixel_values, imgs_sizes=imgs_sizes)
|
||||
return self._extract_final(y, imgs_sizes=imgs_sizes)
|
||||
|
||||
def load_weights(self, weights) -> set[str]:
|
||||
loaded_params: set[str] = set()
|
||||
@@ -558,16 +726,32 @@ class RadioModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
def _extract_final(
    self, y: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """Split encoder output into summary tokens and patch features.

    Every image's span in ``y`` (along dimension 1) starts with ``num_skip``
    non-patch positions, of which the first ``num_cls_tokens`` are taken as
    the summary, followed by that image's patch features.

    Args:
        y: Encoder output.
        imgs_sizes: Optional per-image ``(h, w)`` sizes. ``None`` means the
            whole of ``y`` is treated as a single span.

    Returns:
        ``(summary, features)``: the selected summary tokens flattened from
        dimension 1 onward, and the patch features with the leading
        non-patch positions removed (concatenated across images on the
        dynamic path).
    """
    # Remove CLS + REGISTERS tokens
    patch_generator = self.model.patch_generator
    num_skip = patch_generator.num_skip
    patch_size = patch_generator.patch_size
    num_cls_tokens = patch_generator.num_cls_tokens

    if imgs_sizes is None:
        all_summary = y[:, :num_cls_tokens]
        all_feat = y[:, num_skip:]
    else:
        summaries = []
        all_patches = []
        current_pos = 0
        for num_patches in calc_seq_lens(imgs_sizes, patch_size):
            feat_start = current_pos + num_skip
            feat_end = feat_start + num_patches
            summaries.append(y[:, current_pos : current_pos + num_cls_tokens, :])
            all_patches.append(y[:, feat_start:feat_end, :])
            # The next image's span begins right after this image's patches.
            current_pos = feat_end
        all_summary = torch.cat(summaries, dim=1)
        all_feat = torch.cat(all_patches, dim=1)

    # NOTE(review): on the multi-image path ``summary_idxs`` indexes into the
    # concatenation of all images' summaries - confirm this is intended.
    if self.summary_idxs is not None:
        bb_summary = all_summary[:, self.summary_idxs]
    else:
        bb_summary = all_summary
    return bb_summary.flatten(1), all_feat
|
||||
|
||||
Reference in New Issue
Block a user