[Model] Nemotron Parse 1.1 Support (#30864)
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -427,15 +427,17 @@ class RadioInternVisionModel(nn.Module):
|
||||
to_2tuple(config.patch_size), config.image_size
|
||||
)
|
||||
max_img_size = int(
|
||||
round(config.max_img_size / config.patch_size) * config.patch_size
|
||||
round(config.cpe_max_size / config.patch_size) * config.patch_size
|
||||
)
|
||||
unique_teachers = set(t["name"] for t in config.teachers)
|
||||
self.patch_generator = ViTPatchGenerator(
|
||||
config.patch_size,
|
||||
config.hidden_size,
|
||||
input_dims=self.img_size,
|
||||
max_input_dims=max_img_size,
|
||||
cls_token=True,
|
||||
register_multiple=config.reg_tokens,
|
||||
num_cls_tokens=len(unique_teachers) if config.cls_token_per_teacher else 1,
|
||||
register_multiple=config.register_multiple,
|
||||
)
|
||||
|
||||
self.encoder = InternVisionEncoder(
|
||||
@@ -489,11 +491,20 @@ class RadioModel(nn.Module):
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
summary_idxs = None
|
||||
if config.teachers:
|
||||
summary_idxs = torch.tensor(
|
||||
[i for i, t in enumerate(config.teachers) if t.get("use_summary", True)]
|
||||
)
|
||||
if summary_idxs.numel() > 0:
|
||||
self.register_buffer("summary_idxs", summary_idxs)
|
||||
self.summary_idxs = summary_idxs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
pixel_embeds: torch.Tensor | None = None,
|
||||
) -> torch.FloatTensor:
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
y = self.model(pixel_values)
|
||||
return self._extract_final(y)
|
||||
|
||||
@@ -546,10 +557,17 @@ class RadioModel(nn.Module):
|
||||
|
||||
return loaded_params
|
||||
|
||||
def _extract_final(self, y: torch.Tensor):
|
||||
def _extract_final(
|
||||
self, y: torch.Tensor
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
# Remove CLS + REGISTERS tokens
|
||||
patch_gen = getattr(self.model, "patch_generator", None)
|
||||
if patch_gen is not None:
|
||||
all_summary = y[:, : patch_gen.num_cls_tokens]
|
||||
if self.summary_idxs is not None:
|
||||
bb_summary = all_summary[:, self.summary_idxs]
|
||||
else:
|
||||
bb_summary = all_summary
|
||||
all_feat = y[:, patch_gen.num_skip :]
|
||||
|
||||
return all_feat
|
||||
return bb_summary.flatten(1), all_feat
|
||||
|
||||
Reference in New Issue
Block a user