Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -25,25 +25,44 @@ from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
||||
MultiModalEmbeddings,
|
||||
SupportsMultiModal)
|
||||
from vllm.model_executor.models.internvl import (calculate_internvl_targets,
|
||||
get_internvl_target_ratios)
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
HasInnerState,
|
||||
IsHybrid,
|
||||
MultiModalEmbeddings,
|
||||
SupportsMultiModal,
|
||||
)
|
||||
from vllm.model_executor.models.internvl import (
|
||||
calculate_internvl_targets,
|
||||
get_internvl_target_ratios,
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
|
||||
from vllm.model_executor.models.radio import RadioModel
|
||||
from vllm.model_executor.models.utils import (flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix)
|
||||
from vllm.model_executor.models.utils import (
|
||||
flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargs, MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargs,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
ImageEmbeddingItems,
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
MultiModalDataItems,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.radio import RadioConfig
|
||||
@@ -87,8 +106,9 @@ class NanoNemotronVLImageEmbeddinInputs(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
NanoNemotronVLImageInputs = Union[NanoNemotronVLImagePixelInputs,
|
||||
NanoNemotronVLImageEmbeddinInputs]
|
||||
NanoNemotronVLImageInputs = Union[
|
||||
NanoNemotronVLImagePixelInputs, NanoNemotronVLImageEmbeddinInputs
|
||||
]
|
||||
|
||||
|
||||
class NanoNemotronVLVideoPixelInputs(TensorSchema):
|
||||
@@ -100,6 +120,7 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema):
|
||||
- h: Height of each video frame
|
||||
- w: Width of each video frame
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values_videos"]
|
||||
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
|
||||
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
|
||||
@@ -112,21 +133,19 @@ class NanoNemotronVLVideoEmbeddingInputs(TensorSchema):
|
||||
- f: Total video feature size
|
||||
- h: Hidden size (must match the hidden size of language model backbone)
|
||||
"""
|
||||
|
||||
type: Literal["video_embeds"]
|
||||
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("n", "f", "h")]
|
||||
data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")]
|
||||
|
||||
|
||||
NanoNemotronVLVideoInputs = Union[NanoNemotronVLVideoPixelInputs,
|
||||
NanoNemotronVLVideoEmbeddingInputs]
|
||||
NanoNemotronVLVideoInputs = Union[
|
||||
NanoNemotronVLVideoPixelInputs, NanoNemotronVLVideoEmbeddingInputs
|
||||
]
|
||||
|
||||
|
||||
def dynamic_preprocess(image,
|
||||
*,
|
||||
image_size=512,
|
||||
max_num_tiles=12,
|
||||
use_thumbnail=True,
|
||||
idx=0):
|
||||
def dynamic_preprocess(
|
||||
image, *, image_size=512, max_num_tiles=12, use_thumbnail=True, idx=0
|
||||
):
|
||||
orig_width, orig_height = image.size
|
||||
|
||||
target_ratios = get_internvl_target_ratios(1, max_num_tiles)
|
||||
@@ -136,7 +155,8 @@ def dynamic_preprocess(image,
|
||||
orig_height=orig_height,
|
||||
target_ratios=target_ratios,
|
||||
image_size=image_size,
|
||||
use_thumbnail=False)
|
||||
use_thumbnail=False,
|
||||
)
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
@@ -156,12 +176,12 @@ def dynamic_preprocess(image,
|
||||
processed_images.append(thumbnail_img)
|
||||
|
||||
processed_images = [
|
||||
img.convert("RGB") if img.mode != "RGB" else img
|
||||
for img in processed_images
|
||||
img.convert("RGB") if img.mode != "RGB" else img for img in processed_images
|
||||
]
|
||||
processed_images = [
|
||||
T.Resize((image_size, image_size),
|
||||
interpolation=T.InterpolationMode.BICUBIC)(img)
|
||||
T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)(
|
||||
img
|
||||
)
|
||||
for img in processed_images
|
||||
]
|
||||
processed_images = [T.ToTensor()(img) for img in processed_images]
|
||||
@@ -222,8 +242,9 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
|
||||
"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig, tokenizer: AnyTokenizer,
|
||||
*args, **kwargs) -> None:
|
||||
def __init__(
|
||||
self, config: PretrainedConfig, tokenizer: AnyTokenizer, *args, **kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
@@ -233,7 +254,8 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
patch_size: int = config.patch_size
|
||||
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size)**2 * (config.downsample_ratio**2))
|
||||
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
|
||||
)
|
||||
self.image_size = image_size
|
||||
self.use_thumbnail: bool = config.use_thumbnail
|
||||
self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1)
|
||||
@@ -283,7 +305,8 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
max_num=max_num_tiles,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
idx=idx,
|
||||
) for idx, image in enumerate(images)
|
||||
)
|
||||
for idx, image in enumerate(images)
|
||||
]
|
||||
|
||||
def _preprocess_image(
|
||||
@@ -295,24 +318,22 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
if len(images) == 0:
|
||||
image_inputs = {}
|
||||
else:
|
||||
pixel_values_lst = self._images_to_pixel_values_lst(
|
||||
images, max_num_tiles)
|
||||
pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
|
||||
image_inputs = {
|
||||
"pixel_values_flat":
|
||||
torch.cat(pixel_values_lst),
|
||||
"image_num_patches":
|
||||
torch.tensor([len(item) for item in pixel_values_lst]),
|
||||
"pixel_values_flat": torch.cat(pixel_values_lst),
|
||||
"image_num_patches": torch.tensor(
|
||||
[len(item) for item in pixel_values_lst]
|
||||
),
|
||||
}
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
text = [t.replace('<image>', image_repl.full, 1) for t in text]
|
||||
text = [t.replace("<image>", image_repl.full, 1) for t in text]
|
||||
return text, image_inputs
|
||||
|
||||
def _make_batch_input(self,
|
||||
input_item: Optional[Union[Any, list[Any]]] = None):
|
||||
def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None):
|
||||
if input_item is None:
|
||||
input_item = []
|
||||
if not isinstance(input_item, list):
|
||||
@@ -392,14 +413,14 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
max_num_tiles: int,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> list[torch.Tensor]:
|
||||
|
||||
return [
|
||||
video_to_pixel_values(
|
||||
video,
|
||||
input_size=self.image_size,
|
||||
max_num_tiles=max_num_tiles,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
) for video in videos
|
||||
)
|
||||
for video in videos
|
||||
]
|
||||
|
||||
def _preprocess_video(
|
||||
@@ -419,18 +440,19 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
)
|
||||
|
||||
video_inputs = {
|
||||
"pixel_values_flat_video":
|
||||
torch.cat(pixel_values_lst_video),
|
||||
"video_num_patches":
|
||||
torch.tensor([len(item) for item in pixel_values_lst_video]),
|
||||
"pixel_values_flat_video": torch.cat(pixel_values_lst_video),
|
||||
"video_num_patches": torch.tensor(
|
||||
[len(item) for item in pixel_values_lst_video]
|
||||
),
|
||||
}
|
||||
|
||||
for pixel_values in pixel_values_lst_video:
|
||||
num_patches = pixel_values.shape[0]
|
||||
|
||||
video_repl = self.get_video_repl(self.num_image_token,
|
||||
num_patches, self.video_token)
|
||||
text = [t.replace('<video>', video_repl.full, 1) for t in text]
|
||||
video_repl = self.get_video_repl(
|
||||
self.num_image_token, num_patches, self.video_token
|
||||
)
|
||||
text = [t.replace("<video>", video_repl.full, 1) for t in text]
|
||||
return text, video_inputs
|
||||
|
||||
def __call__(
|
||||
@@ -488,9 +510,9 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
repl_features = video_context_token * self.num_image_token
|
||||
repl_features_with_sep = IMG_START + repl_features + IMG_END
|
||||
# num_patches is equal to num_frames
|
||||
repl_full = ''.join([
|
||||
f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches)
|
||||
])
|
||||
repl_full = "".join(
|
||||
[f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
|
||||
)
|
||||
|
||||
return PromptUpdateDetails.select_text(repl_full, video_context_token)
|
||||
|
||||
@@ -525,8 +547,7 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
||||
max_num_tiles=max_num_tiles,
|
||||
)
|
||||
|
||||
def get_image_size_with_most_features(self,
|
||||
max_num_tiles: int) -> ImageSize:
|
||||
def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
base_size = processor.image_size
|
||||
@@ -544,8 +565,7 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
||||
)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
largest_feature_pinpoint = ImageSize(width=width,
|
||||
height=height)
|
||||
largest_feature_pinpoint = ImageSize(width=width, height=height)
|
||||
|
||||
if largest_feature_size == 0 or largest_feature_pinpoint is None:
|
||||
raise ValueError("Cannot have a largest feature size of 0!")
|
||||
@@ -557,7 +577,8 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
||||
# Use default max_num_tiles for max tokens calculation
|
||||
max_num_tiles = 12
|
||||
target_width, target_height = self.get_image_size_with_most_features(
|
||||
max_num_tiles)
|
||||
max_num_tiles
|
||||
)
|
||||
|
||||
return self.get_num_image_tokens(
|
||||
image_width=target_width,
|
||||
@@ -571,7 +592,7 @@ _I = TypeVar("_I", bound=BaseNanoNemotronVLProcessingInfo)
|
||||
|
||||
|
||||
class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
||||
""" ProcessingInfo extended for video processing"""
|
||||
"""ProcessingInfo extended for video processing"""
|
||||
|
||||
@property
|
||||
def supports_video(self):
|
||||
@@ -595,8 +616,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
||||
processor = self.get_hf_processor() # we get the CustomProcessor here
|
||||
|
||||
max_image_tokens = self.get_max_image_tokens() * max_images
|
||||
max_total_frames = (seq_len -
|
||||
max_image_tokens) // processor.num_image_token
|
||||
max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token
|
||||
max_frames_per_video = max_total_frames // max(max_videos, 1)
|
||||
|
||||
max_frames_per_video = min(max_frames_per_video, MAX_FRAMES)
|
||||
@@ -649,7 +669,8 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
return dict(
|
||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_patches),
|
||||
"image", image_num_patches
|
||||
),
|
||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
@@ -675,7 +696,8 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
def get_replacement_custom(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems)
|
||||
)
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
feature_size = images.get_feature_size(item_idx)
|
||||
@@ -694,9 +716,9 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
local_image_num_patches = image_num_patches
|
||||
if isinstance(local_image_num_patches, torch.Tensor):
|
||||
local_image_num_patches = local_image_num_patches.tolist()
|
||||
if isinstance(
|
||||
local_image_num_patches,
|
||||
(list, tuple)) and item_idx < len(local_image_num_patches):
|
||||
if isinstance(local_image_num_patches, (list, tuple)) and item_idx < len(
|
||||
local_image_num_patches
|
||||
):
|
||||
num_patches = int(local_image_num_patches[item_idx])
|
||||
|
||||
return hf_processor.get_image_repl(feature_size, num_patches)
|
||||
@@ -711,7 +733,8 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
|
||||
class NanoNemotronVLMultiModalProcessor(
|
||||
NanoNemotronBaseVLMultiModalProcessor[NanoNemotronVLProcessingInfo]):
|
||||
NanoNemotronBaseVLMultiModalProcessor[NanoNemotronVLProcessingInfo]
|
||||
):
|
||||
"""MultiModalProcessor extended for video support"""
|
||||
|
||||
def _call_hf_processor(
|
||||
@@ -721,12 +744,15 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(prompt, mm_data,
|
||||
mm_kwargs, tok_kwargs)
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt, mm_data, mm_kwargs, tok_kwargs
|
||||
)
|
||||
|
||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
if self.info.supports_video and (
|
||||
video_token_id := hf_processor.video_token_id) is not None:
|
||||
if (
|
||||
self.info.supports_video
|
||||
and (video_token_id := hf_processor.video_token_id) is not None
|
||||
):
|
||||
processed_outputs["video_token_id"] = torch.tensor(video_token_id)
|
||||
return processed_outputs
|
||||
|
||||
@@ -735,18 +761,17 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_fields = super()._get_mm_fields_config(hf_inputs,
|
||||
hf_processor_mm_kwargs)
|
||||
image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
|
||||
if self.info.supports_video:
|
||||
video_num_patches = hf_inputs.get("video_num_patches",
|
||||
torch.empty(0))
|
||||
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
|
||||
num_videos = len(video_num_patches)
|
||||
video_fields = dict(
|
||||
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_num_patches),
|
||||
"video", video_num_patches
|
||||
),
|
||||
video_num_patches=MultiModalFieldConfig.batched("video"),
|
||||
video_token_id=MultiModalFieldConfig.shared(
|
||||
"video", num_videos))
|
||||
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
|
||||
)
|
||||
else:
|
||||
video_fields = {}
|
||||
|
||||
@@ -781,9 +806,8 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
assert isinstance(num_patches, int)
|
||||
|
||||
return hf_processor.get_video_repl(
|
||||
feature_size,
|
||||
num_patches,
|
||||
video_context_token=hf_processor.video_token)
|
||||
feature_size, num_patches, video_context_token=hf_processor.video_token
|
||||
)
|
||||
|
||||
if self.info.supports_video:
|
||||
prompt_repl = [
|
||||
@@ -792,7 +816,7 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
modality="video",
|
||||
target="<video>",
|
||||
replacement=get_video_replacement_internvl,
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
return prompt_repl
|
||||
@@ -814,23 +838,26 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
) -> MultiModalDataDict:
|
||||
# Use default max_num_tiles for dummy data generation
|
||||
max_num_tiles = 12
|
||||
target_width, target_height = (
|
||||
self.info.get_image_size_with_most_features(max_num_tiles))
|
||||
target_width, target_height = self.info.get_image_size_with_most_features(
|
||||
max_num_tiles
|
||||
)
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
"image": self._get_dummy_images(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class NanoNemotronVLDummyInputsBuilder(
|
||||
NanoNemotronVLDummyInputsBuilder[NanoNemotronVLProcessingInfo]):
|
||||
NanoNemotronVLDummyInputsBuilder[NanoNemotronVLProcessingInfo]
|
||||
):
|
||||
"""DummyInputsBuilder extended for video support"""
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
@@ -844,23 +871,25 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
dummy_image = super().get_dummy_mm_data(seq_len=seq_len,
|
||||
mm_counts=mm_counts,
|
||||
mm_options=mm_options)
|
||||
dummy_image = super().get_dummy_mm_data(
|
||||
seq_len=seq_len, mm_counts=mm_counts, mm_options=mm_options
|
||||
)
|
||||
if self.info.supports_video:
|
||||
config = self.info.get_hf_config()
|
||||
image_size: int = config.force_image_size
|
||||
target_num_frames = \
|
||||
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
|
||||
target_num_frames = self.info.get_num_frames_with_most_features(
|
||||
seq_len, mm_counts
|
||||
)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
dummy_video = {
|
||||
"video":
|
||||
self._get_dummy_videos(width=image_size,
|
||||
height=image_size,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides)
|
||||
"video": self._get_dummy_videos(
|
||||
width=image_size,
|
||||
height=image_size,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides,
|
||||
)
|
||||
}
|
||||
else:
|
||||
dummy_video = {}
|
||||
@@ -872,9 +901,7 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
info=NanoNemotronVLProcessingInfo,
|
||||
dummy_inputs=NanoNemotronVLDummyInputsBuilder,
|
||||
)
|
||||
class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
SupportsMultiModal):
|
||||
|
||||
class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModal):
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
@@ -892,7 +919,8 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
self.patch_size = patch_size
|
||||
self.template = config.template
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size)**2 * (config.downsample_ratio**2))
|
||||
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
|
||||
)
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.ps_version = config.ps_version
|
||||
self.image_tag_type = config.image_tag_type
|
||||
@@ -903,7 +931,8 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
self.vision_model = self.get_vit_model_from_radio_config(config).to(
|
||||
self.language_model.config.torch_dtype)
|
||||
self.language_model.config.torch_dtype
|
||||
)
|
||||
|
||||
# Construct the vision projection.
|
||||
vit_hidden_size = config.vit_hidden_size
|
||||
@@ -911,18 +940,17 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
self.mlp1 = nn.Sequential(
|
||||
RMSNorm(hidden_size=vit_hidden_size *
|
||||
int(1 / self.downsample_ratio)**2,
|
||||
eps=1e-5),
|
||||
RMSNorm(
|
||||
hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
|
||||
eps=1e-5,
|
||||
),
|
||||
nn.Linear(
|
||||
vit_hidden_size * int(1 / self.downsample_ratio)**2,
|
||||
vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
|
||||
vision_projection_hidden_size,
|
||||
bias=False,
|
||||
),
|
||||
ReLUSquaredActivation(),
|
||||
nn.Linear(vision_projection_hidden_size,
|
||||
llm_hidden_size,
|
||||
bias=False),
|
||||
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
|
||||
)
|
||||
self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype)
|
||||
|
||||
@@ -962,17 +990,16 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
def extract_feature(self, pixel_values):
|
||||
vit_embeds = self.vision_model(pixel_values)
|
||||
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
|
||||
h = w = int(vit_embeds.shape[1]**0.5)
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(vit_embeds,
|
||||
scale_factor=self.downsample_ratio)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
|
||||
vit_embeds.shape[-1])
|
||||
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
return vit_embeds
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[NanoNemotronVLImageInputs]:
|
||||
self, **kwargs: object
|
||||
) -> Optional[NanoNemotronVLImageInputs]:
|
||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
@@ -982,8 +1009,10 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
raise ValueError(
|
||||
"Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}"
|
||||
)
|
||||
|
||||
return NanoNemotronVLImageEmbeddinInputs(
|
||||
type="image_embeds",
|
||||
@@ -996,12 +1025,16 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
|
||||
if pixel_values_flat is not None:
|
||||
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat)}")
|
||||
raise ValueError(
|
||||
"Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat)}"
|
||||
)
|
||||
|
||||
if not isinstance(image_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(image_num_patches)}")
|
||||
raise ValueError(
|
||||
"Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(image_num_patches)}"
|
||||
)
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
@@ -1015,7 +1048,8 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: NanoNemotronVLImageInputs) -> torch.Tensor:
|
||||
self, image_input: NanoNemotronVLImageInputs
|
||||
) -> torch.Tensor:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
@@ -1026,22 +1060,20 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
|
||||
# Only one image in the current batch
|
||||
if len(num_patches) == 1:
|
||||
return (image_embeds.view(-1,
|
||||
self.config.text_config.hidden_size), )
|
||||
return (image_embeds.view(-1, self.config.text_config.hidden_size),)
|
||||
|
||||
# NOTE: Image embeddings are split into separate tensors for each image
|
||||
# by the size of each embedding.
|
||||
feature_size = image_embeds.shape[1]
|
||||
image_embeds = image_embeds.view(-1,
|
||||
self.config.text_config.hidden_size)
|
||||
image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)
|
||||
image_feature_sizes = [
|
||||
num_patches * feature_size for num_patches in num_patches
|
||||
]
|
||||
return image_embeds.split(image_feature_sizes)
|
||||
|
||||
def _parse_and_validate_video_input(
|
||||
self,
|
||||
**kwargs: object) -> Optional[NanoNemotronVLVideoPixelInputs]:
|
||||
self, **kwargs: object
|
||||
) -> Optional[NanoNemotronVLVideoPixelInputs]:
|
||||
pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
|
||||
video_num_patches = kwargs.pop("video_num_patches", None)
|
||||
video_embeds = kwargs.pop("video_embeds", None)
|
||||
@@ -1061,15 +1093,18 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
|
||||
if pixel_values_flat_video is not None:
|
||||
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat_video)}")
|
||||
raise ValueError(
|
||||
"Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat_video)}"
|
||||
)
|
||||
|
||||
if not isinstance(video_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(video_num_patches)}")
|
||||
raise ValueError(
|
||||
"Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(video_num_patches)}"
|
||||
)
|
||||
|
||||
pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
|
||||
concat=True)
|
||||
pixel_values_flat_video = flatten_bn(pixel_values_flat_video, concat=True)
|
||||
video_num_patches = flatten_bn(video_num_patches, concat=True)
|
||||
expected_h = expected_w = self.config.force_image_size
|
||||
resolve_bindings = {"h": expected_h, "w": expected_w}
|
||||
@@ -1088,19 +1123,17 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values_flat",
|
||||
"image_embeds") and "images" not in modalities:
|
||||
modalities["images"] = self._parse_and_validate_image_input(
|
||||
**kwargs)
|
||||
if input_key in ("pixel_values_flat_video",
|
||||
) and "videos" not in modalities:
|
||||
modalities["videos"] = self._parse_and_validate_video_input(
|
||||
**kwargs)
|
||||
if (
|
||||
input_key in ("pixel_values_flat", "image_embeds")
|
||||
and "images" not in modalities
|
||||
):
|
||||
modalities["images"] = self._parse_and_validate_image_input(**kwargs)
|
||||
if input_key in ("pixel_values_flat_video",) and "videos" not in modalities:
|
||||
modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
|
||||
|
||||
return modalities
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
# Validate the multimodal input keyword arguments
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if modalities is None:
|
||||
@@ -1193,16 +1226,13 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
default_weight_loader(param, w)
|
||||
elif is_vision_weights(name):
|
||||
# Convert: vision_model.radio_model.* → radio_model.*
|
||||
hf_key = name[len(
|
||||
"vision_model."):] # Remove "vision_model." prefix
|
||||
hf_key = name[len("vision_model.") :] # Remove "vision_model." prefix
|
||||
vision_weights.append((hf_key, w))
|
||||
|
||||
self.language_model.load_weights(llm_weights)
|
||||
self.vision_model.load_weights(vision_weights)
|
||||
|
||||
def print_architecture(self,
|
||||
detailed: bool = True,
|
||||
save_to_file: str = None):
|
||||
def print_architecture(self, detailed: bool = True, save_to_file: str = None):
|
||||
"""
|
||||
Print model architecture with parameter names, shapes, and sizes.
|
||||
|
||||
@@ -1238,20 +1268,26 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
# Group parameters by main component
|
||||
if name.startswith("language_model"):
|
||||
param_groups["language_model"].append(
|
||||
(name, param.shape, param_size, param.dtype))
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
elif name.startswith("vision_model"):
|
||||
param_groups["vision_model"].append(
|
||||
(name, param.shape, param_size, param.dtype))
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
elif name.startswith("mlp1"):
|
||||
param_groups["mlp1"].append(
|
||||
(name, param.shape, param_size, param.dtype))
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
else:
|
||||
param_groups["other"].append(
|
||||
(name, param.shape, param_size, param.dtype))
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
|
||||
if detailed:
|
||||
print(f"{name:<70} | Shape: {str(param.shape):<25} | "
|
||||
f"Size: {param_size:>12,} | Dtype: {param.dtype}")
|
||||
print(
|
||||
f"{name:<70} | Shape: {str(param.shape):<25} | "
|
||||
f"Size: {param_size:>12,} | Dtype: {param.dtype}"
|
||||
)
|
||||
|
||||
print("=" * 100)
|
||||
print("Summary by Component:")
|
||||
@@ -1260,11 +1296,16 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
for component, params in param_groups.items():
|
||||
if params: # Only show components that have parameters
|
||||
component_total = sum(size for _, _, size, _ in params)
|
||||
percentage = ((component_total / total_params) *
|
||||
100 if total_params > 0 else 0)
|
||||
print(f"{component:<20} | Parameters: {len(params):>4} | "
|
||||
f"Total Size: {component_total:>15,} | "
|
||||
f"{percentage:>6.2f}%")
|
||||
percentage = (
|
||||
(component_total / total_params) * 100
|
||||
if total_params > 0
|
||||
else 0
|
||||
)
|
||||
print(
|
||||
f"{component:<20} | Parameters: {len(params):>4} | "
|
||||
f"Total Size: {component_total:>15,} | "
|
||||
f"{percentage:>6.2f}%"
|
||||
)
|
||||
|
||||
print("-" * 60)
|
||||
print(f"{'Total Parameters':<20} | {total_params:>15,}")
|
||||
@@ -1320,10 +1361,9 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
hf_config_vision = hf_config.vision_config
|
||||
model_name = hf_config_vision.args.get("model")
|
||||
if model_name is None:
|
||||
raise ValueError(f'Unsupported vit model type: {model_name}')
|
||||
raise ValueError(f"Unsupported vit model type: {model_name}")
|
||||
|
||||
preferred_resolution = getattr(hf_config_vision,
|
||||
"preferred_resolution", None)
|
||||
preferred_resolution = getattr(hf_config_vision, "preferred_resolution", None)
|
||||
image_size = preferred_resolution[0] if preferred_resolution else 224
|
||||
patch_size = getattr(hf_config_vision, "patch_size", 16)
|
||||
|
||||
@@ -1333,33 +1373,36 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
|
||||
patch_size=patch_size,
|
||||
norm_mean=hf_config.norm_mean,
|
||||
norm_std=hf_config.norm_std,
|
||||
reg_tokens=(hf_config_vision.args.get("register_multiple")
|
||||
if hasattr(hf_config_vision, "args")
|
||||
and isinstance(hf_config_vision.args, dict) else None),
|
||||
reg_tokens=(
|
||||
hf_config_vision.args.get("register_multiple")
|
||||
if hasattr(hf_config_vision, "args")
|
||||
and isinstance(hf_config_vision.args, dict)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
return RadioModel(config=radio_config)
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.language_model.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
input_buffers, **kwargs
|
||||
)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return (self.language_model.mamba_cache.
|
||||
get_seqlen_agnostic_capture_inputs(batch_size))
|
||||
return self.language_model.mamba_cache.get_seqlen_agnostic_capture_inputs(
|
||||
batch_size
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(cls, vllm_config: "VllmConfig"):
|
||||
text_config = vllm_config.model_config.hf_config.text_config
|
||||
temp_vllm_config = copy.deepcopy(vllm_config)
|
||||
temp_vllm_config.model_config.hf_config = text_config
|
||||
return NemotronHForCausalLM.get_mamba_state_shape_from_config(
|
||||
temp_vllm_config)
|
||||
return NemotronHForCausalLM.get_mamba_state_shape_from_config(temp_vllm_config)
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_dtype_from_config(cls, vllm_config: "VllmConfig"):
|
||||
text_config = vllm_config.model_config.hf_config.text_config
|
||||
temp_vllm_config = copy.deepcopy(vllm_config)
|
||||
temp_vllm_config.model_config.hf_config = text_config
|
||||
return NemotronHForCausalLM.get_mamba_state_dtype_from_config(
|
||||
temp_vllm_config)
|
||||
return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config)
|
||||
|
||||
Reference in New Issue
Block a user