support dynamic resolution image encoding for Nemotron Nano VL (#32121)

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
Netanel Haber
2026-01-19 20:15:58 +02:00
committed by GitHub
parent 2636d76257
commit cd3ac5b797
3 changed files with 754 additions and 163 deletions

View File

@@ -21,7 +21,11 @@ from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionEncoder
from vllm.model_executor.models.intern_vit import (
InternParallelAttention,
InternVisionEncoder,
InternVisionEncoderLayer,
)
# A spatial input dimension: either a single int or an explicit pair of ints.
input_dim_t: TypeAlias = int | tuple[int, int]
# Either three floats or a tensor; presumably per-channel normalization
# statistics (mean/std) — TODO confirm against the users of this alias.
norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor
@@ -43,6 +47,15 @@ to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def calc_seq_len(size: tuple[int, int], patch_size: int) -> int:
    """Return how many whole square patches fit into an image.

    Args:
        size: Image size as a ``(height, width)`` pair.
        patch_size: Side length of one square patch.

    Returns:
        The number of complete patches. Partial patches along either axis
        are dropped by the floor division.
    """
    height, width = size
    rows = height // patch_size
    cols = width // patch_size
    return rows * cols
def calc_seq_lens(sizes: list[tuple[int, int]], patch_size: int) -> list[int]:
    """Return the patch count of every image in ``sizes``, in input order.

    Args:
        sizes: One ``(height, width)`` pair per image.
        patch_size: Side length of one square patch.

    Returns:
        A list with ``calc_seq_len(size, patch_size)`` for each entry.
    """
    return [calc_seq_len(image_size, patch_size) for image_size in sizes]
class ClsToken(nn.Module):
def __init__(
self,
@@ -164,15 +177,73 @@ class ViTPatchGenerator(nn.Module):
nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
patches = self.embed_patches(x)
patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
patches = self.cls_token(patches)
def forward(
self, x: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
) -> torch.Tensor:
if imgs_sizes is not None:
patches = self.embedder(x)
patches, pos_enc = self.apply_pos_enc_dynamic(
patches, imgs_sizes=imgs_sizes
)
patches = self.cls_token_dynamic(patches, imgs_sizes=imgs_sizes)
else:
patches = self.embed_patches(x)
patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
patches = self.cls_token(patches)
patches = self.patch_normalizer(patches)
if self.return_pos_enc:
return patches, pos_enc
return patches
def apply_pos_enc_dynamic(
self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> tuple[torch.Tensor, torch.Tensor | None]:
if not self.abs_pos:
return patches, None
current_length = 0
pos_enc_list = []
for size in imgs_sizes:
seq_length = calc_seq_len(size, self.patch_size)
img_patches = patches[:, current_length : current_length + seq_length, :]
pos_enc = self.get_pos_enc(patches.shape[0], input_size=size)
img_patches_with_pos = img_patches + pos_enc
patches = torch.cat(
[
patches[:, :current_length, :],
img_patches_with_pos,
patches[:, current_length + seq_length :, :],
],
dim=1,
)
pos_enc_list.append(pos_enc)
current_length += seq_length
full_pos_enc = torch.cat(pos_enc_list, dim=1) if pos_enc_list else None
return patches, full_pos_enc
def cls_token_dynamic(
    self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
    """Insert the class token(s) in front of every image's span of patches.

    Args:
        patches: Patch embeddings indexed as ``[batch, sequence, channel]``,
            with the images laid out back to back along dimension 1.
        imgs_sizes: One size per image; each image occupies
            ``calc_seq_len(size, self.patch_size)`` sequence positions.

    Returns:
        ``patches`` unchanged when the class token is disabled. Otherwise a
        new tensor in which each image's span is preceded by
        ``self.cls_token.token``. Only the spans described by
        ``imgs_sizes`` are copied into the result.
    """
    if not self.cls_token.enabled:
        return patches

    # The expanded token is identical for every image, so build it once.
    class_token = self.cls_token.token.unsqueeze(0).expand(
        patches.shape[0], -1, -1
    )

    out = []
    offset = 0
    for seq_len in calc_seq_lens(imgs_sizes, self.patch_size):
        out.append(class_token)
        out.append(patches[:, offset : offset + seq_len, :])
        offset += seq_len

    return torch.cat(out, dim=1)
@property
def apply_cls_token(self):
    """Expose the ``enabled`` flag of the ``cls_token`` submodule."""
    cls_token_module = self.cls_token
    return cls_token_module.enabled
@@ -406,6 +477,66 @@ class ViTPatchLinear(nn.Linear):
self.patch_size = patch_size
class RadioParallelAttention(InternParallelAttention):
    """``InternParallelAttention`` variant that accepts an attention mask."""

    def forward(
        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run self-attention over ``x``, optionally restricted by ``attn_mask``.

        Args:
            x: Rank-3 input indexed as ``[batch, sequence, channel]``.
            attn_mask: Optional mask handed to
                ``F.scaled_dot_product_attention``. When ``None`` the parent
                implementation is used unchanged.

        Returns:
            The output of ``self.proj`` applied to the attended values.
        """
        if attn_mask is None:
            return super().forward(x)

        batch, seq_len, _ = x.shape
        qkv, _ = self.qkv(x)
        query, key, value = qkv.chunk(3, dim=-1)

        if self.qk_normalization:
            query, key = self._apply_qk_norm(query, key)

        # Split channels into heads, then move the head axis ahead of the
        # sequence axis as scaled_dot_product_attention expects.
        head_shape = (batch, seq_len, self.num_heads_per_partition, self.head_dim)
        query = query.view(head_shape).transpose(1, 2)
        key = key.view(head_shape).transpose(1, 2)
        value = value.view(head_shape).transpose(1, 2)

        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, scale=self.scale
        )
        # Merge the heads back into a single channel dimension.
        attended = attended.transpose(1, 2).reshape(batch, seq_len, -1)

        projected, _ = self.proj(attended)
        return projected
class RadioVisionEncoderLayer(InternVisionEncoderLayer):
    """Encoder layer whose attention is a ``RadioParallelAttention``."""

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to the parent, forcing ``attn_cls``."""
        super().__init__(*args, attn_cls=RadioParallelAttention, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Apply the attention and MLP sub-blocks, each with a scaled residual.

        Args:
            hidden_states: Input activations of the layer.
            attn_mask: Optional mask passed through to ``self.attn``.

        Returns:
            The activations after both residual updates.
        """
        attn_out = self.attn(self.norm1(hidden_states), attn_mask=attn_mask)
        hidden_states = hidden_states + attn_out * self.ls1

        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out * self.ls2
class RadioVisionEncoder(InternVisionEncoder):
    """Encoder stack built from ``RadioVisionEncoderLayer`` layers."""

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to the parent, forcing ``layer_cls``."""
        super().__init__(*args, layer_cls=RadioVisionEncoderLayer, **kwargs)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Feed ``inputs_embeds`` through every layer in order.

        Args:
            inputs_embeds: Activations entering the first layer.
            attn_mask: Optional mask handed to every layer.

        Returns:
            The output of the last layer, or ``inputs_embeds`` itself when
            ``self.layers`` is empty.
        """
        activations = inputs_embeds
        for layer in self.layers:
            activations = layer(activations, attn_mask=attn_mask)
        return activations
class RadioInternVisionModel(nn.Module):
packed_modules_mapping = {
"qkv": ["qkv"],
@@ -440,7 +571,7 @@ class RadioInternVisionModel(nn.Module):
register_multiple=config.register_multiple,
)
self.encoder = InternVisionEncoder(
self.encoder = RadioVisionEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
@@ -459,10 +590,45 @@ class RadioInternVisionModel(nn.Module):
def get_input_embeddings(self):
    """Return the object stored in ``self.embeddings``."""
    embeddings = self.embeddings
    return embeddings
def forward(self, x: torch.Tensor) -> torch.FloatTensor:
def create_inter_image_attention_mask(
    self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> torch.Tensor:
    """Build a block-diagonal boolean mask that keeps images separate.

    Each image contributes ``calc_seq_len(size, patch_size) + num_skip``
    positions (its patches plus the tokens counted by ``num_skip``).

    Args:
        imgs_sizes: One size per image, in sequence order.
        device: Device on which the mask is allocated.

    Returns:
        A square ``torch.bool`` tensor over all positions of all images.
        Entry ``[i, j]`` is ``True`` only when positions ``i`` and ``j``
        belong to the same image.
    """
    generator = self.patch_generator
    tokens_per_image = [
        seq_len + generator.num_skip
        for seq_len in calc_seq_lens(imgs_sizes, generator.patch_size)
    ]
    total = sum(tokens_per_image)

    # Start fully masked out, then open one square block per image.
    mask = torch.zeros((total, total), dtype=torch.bool, device=device)

    offset = 0
    for count in tokens_per_image:
        block = slice(offset, offset + count)
        mask[block, block] = True
        offset += count

    return mask
def forward(
self,
x: torch.Tensor,
imgs_sizes: torch.Tensor | None = None,
) -> torch.FloatTensor:
assert self.patch_generator is not None
hidden_states = self.patch_generator(x)
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
attn_mask = None
if imgs_sizes is not None and len(imgs_sizes) > 1:
# Dynamic Resolution
attn_mask = self.create_inter_image_attention_mask(
imgs_sizes, device=x.device
)
encoder_outputs = self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask)
return encoder_outputs
@@ -504,9 +670,11 @@ class RadioModel(nn.Module):
self,
pixel_values: torch.Tensor | None = None,
pixel_embeds: torch.Tensor | None = None,
*,
imgs_sizes: torch.Tensor | None = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
y = self.model(pixel_values)
return self._extract_final(y)
y = self.model(pixel_values, imgs_sizes=imgs_sizes)
return self._extract_final(y, imgs_sizes=imgs_sizes)
def load_weights(self, weights) -> set[str]:
loaded_params: set[str] = set()
@@ -558,16 +726,32 @@ class RadioModel(nn.Module):
return loaded_params
def _extract_final(
self, y: torch.Tensor
self, y: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
# Remove CLS + REGISTERS tokens
patch_gen = getattr(self.model, "patch_generator", None)
if patch_gen is not None:
all_summary = y[:, : patch_gen.num_cls_tokens]
if self.summary_idxs is not None:
bb_summary = all_summary[:, self.summary_idxs]
else:
bb_summary = all_summary
all_feat = y[:, patch_gen.num_skip :]
num_skip = self.model.patch_generator.num_skip
patch_size = self.model.patch_generator.patch_size
num_cls_tokens = self.model.patch_generator.num_cls_tokens
if imgs_sizes is None:
all_summary = y[:, :num_cls_tokens]
all_feat = y[:, num_skip:]
else:
all_patches = []
summaries = []
current_pos = 0
for num_patches in calc_seq_lens(imgs_sizes, patch_size):
patches = y[
:, current_pos + num_skip : current_pos + num_skip + num_patches, :
]
all_patches.append(patches)
summary = y[:, current_pos : current_pos + num_cls_tokens, :]
summaries.append(summary)
current_pos += num_skip + num_patches
all_summary = torch.cat(summaries, dim=1)
all_feat = torch.cat(all_patches, dim=1)
if self.summary_idxs is not None:
bb_summary = all_summary[:, self.summary_idxs]
else:
bb_summary = all_summary
return bb_summary.flatten(1), all_feat