[Model] Nemotron Parse 1.1 Support (#30864)

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
amitz-nv
2026-01-05 23:00:14 +02:00
committed by GitHub
parent af1b07b0c5
commit ee21291825
13 changed files with 1117 additions and 31 deletions

View File

@@ -427,15 +427,17 @@ class RadioInternVisionModel(nn.Module):
to_2tuple(config.patch_size), config.image_size
)
max_img_size = int(
round(config.max_img_size / config.patch_size) * config.patch_size
round(config.cpe_max_size / config.patch_size) * config.patch_size
)
unique_teachers = set(t["name"] for t in config.teachers)
self.patch_generator = ViTPatchGenerator(
config.patch_size,
config.hidden_size,
input_dims=self.img_size,
max_input_dims=max_img_size,
cls_token=True,
register_multiple=config.reg_tokens,
num_cls_tokens=len(unique_teachers) if config.cls_token_per_teacher else 1,
register_multiple=config.register_multiple,
)
self.encoder = InternVisionEncoder(
@@ -489,11 +491,20 @@ class RadioModel(nn.Module):
prefix=prefix,
)
summary_idxs = None
if config.teachers:
summary_idxs = torch.tensor(
[i for i, t in enumerate(config.teachers) if t.get("use_summary", True)]
)
if summary_idxs.numel() > 0:
self.register_buffer("summary_idxs", summary_idxs)
self.summary_idxs = summary_idxs
def forward(
self,
pixel_values: torch.Tensor | None = None,
pixel_embeds: torch.Tensor | None = None,
) -> torch.FloatTensor:
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
y = self.model(pixel_values)
return self._extract_final(y)
@@ -546,10 +557,17 @@ class RadioModel(nn.Module):
return loaded_params
def _extract_final(self, y: torch.Tensor):
def _extract_final(
self, y: torch.Tensor
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
# Remove CLS + REGISTERS tokens
patch_gen = getattr(self.model, "patch_generator", None)
if patch_gen is not None:
all_summary = y[:, : patch_gen.num_cls_tokens]
if self.summary_idxs is not None:
bb_summary = all_summary[:, self.summary_idxs]
else:
bb_summary = all_summary
all_feat = y[:, patch_gen.num_skip :]
return all_feat
return bb_summary.flatten(1), all_feat