[Misc] Clean up processing logic (#37541)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1221,49 +1221,33 @@ class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessing
|
|||||||
num_videos: int,
|
num_videos: int,
|
||||||
overrides: VideoDummyOptions | None = None,
|
overrides: VideoDummyOptions | None = None,
|
||||||
):
|
):
|
||||||
if overrides:
|
# ernie4.5-vl requires at least 2 frames
|
||||||
if overrides.num_frames:
|
num_frames = max(num_frames, 2)
|
||||||
if overrides.num_frames > num_frames:
|
if overrides and overrides.num_frames:
|
||||||
logger.warning(
|
overrides.num_frames = max(overrides.num_frames, 2)
|
||||||
"video.num_frames override (%d) exceeds model's "
|
|
||||||
"maximum number of frames (%d), will be ignored",
|
videos = super()._get_dummy_videos(
|
||||||
overrides.num_frames,
|
width=width,
|
||||||
num_frames,
|
height=height,
|
||||||
)
|
num_frames=num_frames,
|
||||||
num_frames = min(num_frames, overrides.num_frames)
|
num_videos=num_videos,
|
||||||
if overrides.width:
|
overrides=overrides,
|
||||||
if overrides.width > width:
|
)
|
||||||
logger.warning(
|
videos = [v.copy() for v in videos]
|
||||||
"video.width override (%d) exceeds model's "
|
|
||||||
"maximum width (%d), will be ignored",
|
|
||||||
overrides.width,
|
|
||||||
width,
|
|
||||||
)
|
|
||||||
width = min(width, overrides.width)
|
|
||||||
if overrides.height:
|
|
||||||
if overrides.height > height:
|
|
||||||
logger.warning(
|
|
||||||
"video.height override (%d) exceeds model's "
|
|
||||||
"maximum height (%d), will be ignored",
|
|
||||||
overrides.height,
|
|
||||||
height,
|
|
||||||
)
|
|
||||||
height = min(height, overrides.height)
|
|
||||||
num_frames = max(num_frames, 2) # ernie4.5-vl requires at least 2 frames
|
|
||||||
|
|
||||||
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
|
|
||||||
video_items = []
|
video_items = []
|
||||||
for i in range(num_videos):
|
for video in videos:
|
||||||
|
video_num_frames = video.shape[0]
|
||||||
video_metadata = {
|
video_metadata = {
|
||||||
"fps": 2.0,
|
"fps": 2.0,
|
||||||
"duration": num_frames / 2.0,
|
"duration": video_num_frames / 2.0,
|
||||||
"total_num_frames": num_frames,
|
"total_num_frames": video_num_frames,
|
||||||
"frames_indices": [i for i in range(num_frames)],
|
"frames_indices": list(range(video_num_frames)),
|
||||||
"video_backend": "opencv",
|
"video_backend": "opencv",
|
||||||
"do_sample_frames": False,
|
"do_sample_frames": False,
|
||||||
}
|
}
|
||||||
video_item = (video.copy(), video_metadata)
|
video_items.append((video, video_metadata))
|
||||||
video_items.append(video_item)
|
|
||||||
return video_items
|
return video_items
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1206,49 +1206,32 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
|
|||||||
num_videos: int,
|
num_videos: int,
|
||||||
overrides: VideoDummyOptions | None = None,
|
overrides: VideoDummyOptions | None = None,
|
||||||
) -> list[VideoItem]:
|
) -> list[VideoItem]:
|
||||||
if overrides:
|
# GLM 4.6V requires at least 2 frames
|
||||||
if overrides.num_frames:
|
num_frames = max(num_frames, 2)
|
||||||
if overrides.num_frames > num_frames:
|
if overrides and overrides.num_frames:
|
||||||
logger.warning(
|
overrides.num_frames = max(overrides.num_frames, 2)
|
||||||
"video.num_frames override (%d) exceeds model's "
|
|
||||||
"maximum number of frames (%d), will be ignored",
|
videos = super()._get_dummy_videos(
|
||||||
overrides.num_frames,
|
width=width,
|
||||||
num_frames,
|
height=height,
|
||||||
)
|
num_frames=num_frames,
|
||||||
num_frames = min(num_frames, overrides.num_frames)
|
num_videos=num_videos,
|
||||||
if overrides.width:
|
overrides=overrides,
|
||||||
if overrides.width > width:
|
)
|
||||||
logger.warning(
|
videos = [v.copy() for v in videos]
|
||||||
"video.width override (%d) exceeds model's "
|
|
||||||
"maximum width (%d), will be ignored",
|
|
||||||
overrides.width,
|
|
||||||
width,
|
|
||||||
)
|
|
||||||
width = min(width, overrides.width)
|
|
||||||
if overrides.height:
|
|
||||||
if overrides.height > height:
|
|
||||||
logger.warning(
|
|
||||||
"video.height override (%d) exceeds model's "
|
|
||||||
"maximum height (%d), will be ignored",
|
|
||||||
overrides.height,
|
|
||||||
height,
|
|
||||||
)
|
|
||||||
height = min(height, overrides.height)
|
|
||||||
|
|
||||||
num_frames = max(num_frames, 2) # GLM 4.6V requires 2 frames
|
|
||||||
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
|
|
||||||
video_items = []
|
video_items = []
|
||||||
for i in range(num_videos):
|
for video in videos:
|
||||||
|
video_num_frames = video.shape[0]
|
||||||
video_metadata = {
|
video_metadata = {
|
||||||
"fps": 2.0,
|
"fps": 2.0,
|
||||||
"duration": num_frames / 2.0,
|
"duration": video_num_frames / 2.0,
|
||||||
"total_num_frames": num_frames,
|
"total_num_frames": video_num_frames,
|
||||||
"frames_indices": [i for i in range(num_frames)],
|
"frames_indices": list(range(video_num_frames)),
|
||||||
"video_backend": "opencv",
|
"video_backend": "opencv",
|
||||||
"do_sample_frames": False,
|
"do_sample_frames": False,
|
||||||
}
|
}
|
||||||
video_item = (video.copy(), video_metadata)
|
video_items.append((video, video_metadata))
|
||||||
video_items.append(video_item)
|
|
||||||
|
|
||||||
return video_items
|
return video_items
|
||||||
|
|
||||||
|
|||||||
@@ -8,14 +8,13 @@
|
|||||||
# Copyright (c) 2024 H2O.AI
|
# Copyright (c) 2024 H2O.AI
|
||||||
# Licensed under Apache 2.0 License [see LICENSE for details]
|
# Licensed under Apache 2.0 License [see LICENSE for details]
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import MultiModalKwargsItems
|
from vllm.multimodal.inputs import BatchedTensorInputs
|
||||||
from vllm.multimodal.parse import (
|
from vllm.multimodal.parse import (
|
||||||
ImageEmbeddingItems,
|
ImageEmbeddingItems,
|
||||||
ImageProcessorItems,
|
ImageProcessorItems,
|
||||||
@@ -25,7 +24,6 @@ from vllm.multimodal.processing.processor import (
|
|||||||
MultiModalProcessingInfo,
|
MultiModalProcessingInfo,
|
||||||
ProcessorInputs,
|
ProcessorInputs,
|
||||||
PromptReplacement,
|
PromptReplacement,
|
||||||
PromptUpdate,
|
|
||||||
TimingContext,
|
TimingContext,
|
||||||
)
|
)
|
||||||
from vllm.transformers_utils.processors.h2ovl import H2OVLImageProcessor, H2OVLProcessor
|
from vllm.transformers_utils.processors.h2ovl import H2OVLImageProcessor, H2OVLProcessor
|
||||||
@@ -86,15 +84,12 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
|
|||||||
|
|
||||||
|
|
||||||
class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]):
|
class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]):
|
||||||
def _get_prompt_updates(
|
def _get_prompt_repl_image(
|
||||||
self,
|
self,
|
||||||
mm_items: MultiModalDataItems,
|
mm_items: MultiModalDataItems,
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor: H2OVLProcessor,
|
||||||
out_mm_kwargs: MultiModalKwargsItems,
|
out_mm_data: BatchedTensorInputs,
|
||||||
) -> Sequence[PromptUpdate]:
|
):
|
||||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
|
||||||
|
|
||||||
out_mm_data = out_mm_kwargs.get_data()
|
|
||||||
if "image_num_patches" in out_mm_data:
|
if "image_num_patches" in out_mm_data:
|
||||||
image_num_patches = out_mm_data["image_num_patches"]
|
image_num_patches = out_mm_data["image_num_patches"]
|
||||||
assert isinstance(image_num_patches, torch.Tensor)
|
assert isinstance(image_num_patches, torch.Tensor)
|
||||||
@@ -130,13 +125,11 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn
|
|||||||
|
|
||||||
return hf_processor.get_image_repl(num_patches, num_features=feature_size)
|
return hf_processor.get_image_repl(num_patches, num_features=feature_size)
|
||||||
|
|
||||||
return [
|
return PromptReplacement(
|
||||||
PromptReplacement(
|
|
||||||
modality="image",
|
modality="image",
|
||||||
target="<image>",
|
target="<image>",
|
||||||
replacement=get_replacement_internvl,
|
replacement=get_replacement_internvl,
|
||||||
)
|
)
|
||||||
]
|
|
||||||
|
|
||||||
def _cached_apply_hf_processor(
|
def _cached_apply_hf_processor(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from vllm.model_executor.models.intern_vit import (
|
|||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (
|
from vllm.multimodal.inputs import (
|
||||||
|
BatchedTensorInputs,
|
||||||
MultiModalDataDict,
|
MultiModalDataDict,
|
||||||
MultiModalFieldConfig,
|
MultiModalFieldConfig,
|
||||||
MultiModalKwargsItems,
|
MultiModalKwargsItems,
|
||||||
@@ -238,11 +239,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
|
|
||||||
return processed_outputs
|
return processed_outputs
|
||||||
|
|
||||||
def _get_mm_fields_config(
|
def _get_image_fields_config(self, hf_inputs: BatchFeature):
|
||||||
self,
|
|
||||||
hf_inputs: BatchFeature,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
|
||||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||||
num_images = len(image_num_patches)
|
num_images = len(image_num_patches)
|
||||||
|
|
||||||
@@ -255,15 +252,19 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prompt_updates(
|
def _get_mm_fields_config(
|
||||||
|
self,
|
||||||
|
hf_inputs: BatchFeature,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
|
return self._get_image_fields_config(hf_inputs)
|
||||||
|
|
||||||
|
def _get_prompt_repl_image(
|
||||||
self,
|
self,
|
||||||
mm_items: MultiModalDataItems,
|
mm_items: MultiModalDataItems,
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor: InternVLProcessor,
|
||||||
out_mm_kwargs: MultiModalKwargsItems,
|
out_mm_data: BatchedTensorInputs,
|
||||||
) -> Sequence[PromptUpdate]:
|
):
|
||||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
|
||||||
|
|
||||||
out_mm_data = out_mm_kwargs.get_data()
|
|
||||||
if "image_num_patches" in out_mm_data:
|
if "image_num_patches" in out_mm_data:
|
||||||
image_num_patches = out_mm_data["image_num_patches"]
|
image_num_patches = out_mm_data["image_num_patches"]
|
||||||
assert isinstance(image_num_patches, torch.Tensor)
|
assert isinstance(image_num_patches, torch.Tensor)
|
||||||
@@ -296,12 +297,23 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
|
|
||||||
return hf_processor.get_image_repl(num_patches, num_features=feature_size)
|
return hf_processor.get_image_repl(num_patches, num_features=feature_size)
|
||||||
|
|
||||||
return [
|
return PromptReplacement(
|
||||||
PromptReplacement(
|
|
||||||
modality="image",
|
modality="image",
|
||||||
target="<image>",
|
target="<image>",
|
||||||
replacement=get_replacement_internvl,
|
replacement=get_replacement_internvl,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_prompt_updates(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
out_mm_kwargs: MultiModalKwargsItems,
|
||||||
|
) -> Sequence[PromptUpdate]:
|
||||||
|
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||||
|
out_mm_data = out_mm_kwargs.get_data()
|
||||||
|
|
||||||
|
return [
|
||||||
|
self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -455,44 +467,35 @@ class InternVLMultiModalProcessor(
|
|||||||
|
|
||||||
return processed_outputs
|
return processed_outputs
|
||||||
|
|
||||||
def _get_mm_fields_config(
|
def _get_video_fields_config(self, hf_inputs: BatchFeature):
|
||||||
self,
|
|
||||||
hf_inputs: BatchFeature,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
|
||||||
image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
|
|
||||||
if self.info.ctx_video_token:
|
|
||||||
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
|
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
|
||||||
num_videos = len(video_num_patches)
|
num_videos = len(video_num_patches)
|
||||||
video_fields = dict(
|
|
||||||
|
return dict(
|
||||||
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
|
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"video", video_num_patches
|
"video", video_num_patches
|
||||||
),
|
),
|
||||||
video_num_patches=MultiModalFieldConfig.batched("video"),
|
video_num_patches=MultiModalFieldConfig.batched("video"),
|
||||||
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
|
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
video_fields = {}
|
|
||||||
|
|
||||||
return image_fields | video_fields
|
def _get_mm_fields_config(
|
||||||
|
self,
|
||||||
|
hf_inputs: BatchFeature,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
|
fields = self._get_image_fields_config(hf_inputs)
|
||||||
|
if self.info.ctx_video_token:
|
||||||
|
fields |= self._get_video_fields_config(hf_inputs)
|
||||||
|
|
||||||
def _get_prompt_updates(
|
return fields
|
||||||
|
|
||||||
|
def _get_prompt_repl_video(
|
||||||
self,
|
self,
|
||||||
mm_items: MultiModalDataItems,
|
mm_items: MultiModalDataItems,
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor: InternVLProcessor,
|
||||||
out_mm_kwargs: MultiModalKwargsItems,
|
out_mm_data: BatchedTensorInputs,
|
||||||
) -> Sequence[PromptUpdate]:
|
):
|
||||||
prompt_repl = super()._get_prompt_updates(
|
|
||||||
mm_items=mm_items,
|
|
||||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
|
||||||
out_mm_kwargs=out_mm_kwargs,
|
|
||||||
)
|
|
||||||
if self.info.ctx_video_token is None:
|
|
||||||
return prompt_repl
|
|
||||||
|
|
||||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
|
||||||
|
|
||||||
out_mm_data = out_mm_kwargs.get_data()
|
|
||||||
if "video_num_patches" in out_mm_data:
|
if "video_num_patches" in out_mm_data:
|
||||||
video_num_patches = out_mm_data["video_num_patches"]
|
video_num_patches = out_mm_data["video_num_patches"]
|
||||||
assert isinstance(video_num_patches, torch.Tensor)
|
assert isinstance(video_num_patches, torch.Tensor)
|
||||||
@@ -507,14 +510,30 @@ class InternVLMultiModalProcessor(
|
|||||||
|
|
||||||
return hf_processor.get_video_repl(num_patches)
|
return hf_processor.get_video_repl(num_patches)
|
||||||
|
|
||||||
return [
|
return PromptReplacement(
|
||||||
*prompt_repl,
|
|
||||||
PromptReplacement(
|
|
||||||
modality="video",
|
modality="video",
|
||||||
target="<video>",
|
target="<video>",
|
||||||
replacement=get_video_replacement_internvl,
|
replacement=get_video_replacement_internvl,
|
||||||
),
|
)
|
||||||
|
|
||||||
|
def _get_prompt_updates(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
out_mm_kwargs: MultiModalKwargsItems,
|
||||||
|
) -> Sequence[PromptUpdate]:
|
||||||
|
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||||
|
out_mm_data = out_mm_kwargs.get_data()
|
||||||
|
|
||||||
|
prompt_repls = [
|
||||||
|
self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data),
|
||||||
]
|
]
|
||||||
|
if self.info.ctx_video_token is not None:
|
||||||
|
prompt_repls.append(
|
||||||
|
self._get_prompt_repl_video(mm_items, hf_processor, out_mm_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt_repls
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
|||||||
@@ -1913,22 +1913,32 @@ class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]):
|
|||||||
height: int,
|
height: int,
|
||||||
num_frames: int,
|
num_frames: int,
|
||||||
num_videos: int,
|
num_videos: int,
|
||||||
|
overrides: VideoDummyOptions | None = None,
|
||||||
) -> list[VideoItem]:
|
) -> list[VideoItem]:
|
||||||
video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
|
videos = super()._get_dummy_videos(
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
num_frames=num_frames,
|
||||||
|
num_videos=num_videos,
|
||||||
|
overrides=overrides,
|
||||||
|
)
|
||||||
|
videos = [v.copy() for v in videos]
|
||||||
|
|
||||||
video_items = []
|
video_items = []
|
||||||
for i in range(num_videos):
|
for video in videos:
|
||||||
|
video_num_frames = video.shape[0]
|
||||||
video_metadata = {
|
video_metadata = {
|
||||||
"fps": 2.0,
|
"fps": 2.0,
|
||||||
"duration": num_frames / 2.0,
|
"duration": video_num_frames / 2.0,
|
||||||
"total_num_frames": num_frames,
|
"total_num_frames": video_num_frames,
|
||||||
"frames_indices": list(range(num_frames)),
|
"frames_indices": list(range(video_num_frames)),
|
||||||
"video_backend": "decord",
|
"video_backend": "decord",
|
||||||
"do_sample_frames": False,
|
"do_sample_frames": False,
|
||||||
"height": height,
|
"height": height,
|
||||||
"width": width,
|
"width": width,
|
||||||
}
|
}
|
||||||
video_item = (video.copy(), video_metadata)
|
video_items.append((video, video_metadata))
|
||||||
video_items.append(video_item)
|
|
||||||
return video_items
|
return video_items
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,10 +10,9 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from abc import abstractmethod
|
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Annotated, Literal, TypeAlias, TypeVar
|
from typing import Annotated, Literal, TypeAlias
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -47,6 +46,7 @@ from vllm.multimodal.evs import (
|
|||||||
)
|
)
|
||||||
from vllm.multimodal.inputs import (
|
from vllm.multimodal.inputs import (
|
||||||
AudioItem,
|
AudioItem,
|
||||||
|
BatchedTensorInputs,
|
||||||
MultiModalDataDict,
|
MultiModalDataDict,
|
||||||
MultiModalFieldConfig,
|
MultiModalFieldConfig,
|
||||||
MultiModalInputs,
|
MultiModalInputs,
|
||||||
@@ -196,21 +196,58 @@ NanoNemotronVLVideoInputs: TypeAlias = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
class NanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
||||||
"""Basic image-only ProcessingInfo for InternVL-style models."""
|
def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
|
||||||
|
return self.ctx.init_processor(
|
||||||
|
NanoNemotronVLProcessor,
|
||||||
|
config=self.get_hf_config(),
|
||||||
|
tokenizer=self.get_tokenizer(),
|
||||||
|
video_token=self.get_video_token(),
|
||||||
|
video_pruning_rate=self.get_video_pruning_rate(),
|
||||||
|
max_model_len=self.ctx.model_config.max_model_len,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@cached_property
|
||||||
def get_hf_processor(
|
def is_dynamic_tiler(self) -> bool:
|
||||||
self,
|
return self.get_hf_processor().dynamic_tiler is not None
|
||||||
**kwargs: object,
|
|
||||||
) -> BaseNanoNemotronVLProcessor:
|
@cached_property
|
||||||
raise NotImplementedError
|
def supports_video(self):
|
||||||
|
return self.get_hf_processor().supports_video
|
||||||
|
|
||||||
|
def get_video_token(self) -> str | None:
|
||||||
|
return IMG_CONTEXT
|
||||||
|
|
||||||
|
def get_video_pruning_rate(self) -> float | None:
|
||||||
|
return self.ctx.get_mm_config().video_pruning_rate
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_extractor(self) -> ParakeetExtractor | None:
|
||||||
|
return self.get_hf_processor().audio_extractor
|
||||||
|
|
||||||
def get_default_tok_params(self) -> TokenizeParams:
|
def get_default_tok_params(self) -> TokenizeParams:
|
||||||
return super().get_default_tok_params().with_kwargs(add_special_tokens=False)
|
return super().get_default_tok_params().with_kwargs(add_special_tokens=False)
|
||||||
|
|
||||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||||
return {"image": None}
|
image_limit = {"image": None}
|
||||||
|
video_limit = {"video": None} if self.supports_video else {}
|
||||||
|
audio_limit = {"audio": None} if self.audio_extractor is not None else {}
|
||||||
|
return {**image_limit, **video_limit, **audio_limit}
|
||||||
|
|
||||||
|
def get_data_parser(self):
|
||||||
|
target_sr = None
|
||||||
|
target_channels = None
|
||||||
|
if extractor := self.audio_extractor:
|
||||||
|
target_sr = extractor.sampling_rate
|
||||||
|
target_channels = 1
|
||||||
|
|
||||||
|
return MultiModalDataParser(
|
||||||
|
video_needs_metadata=True,
|
||||||
|
target_sr=target_sr,
|
||||||
|
target_channels=target_channels,
|
||||||
|
expected_hidden_size=self._get_expected_hidden_size(),
|
||||||
|
)
|
||||||
|
|
||||||
def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize:
|
def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize:
|
||||||
processor = self.get_hf_processor()
|
processor = self.get_hf_processor()
|
||||||
@@ -248,46 +285,6 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
|||||||
max_num_tiles=max_num_tiles,
|
max_num_tiles=max_num_tiles,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_I = TypeVar("_I", bound=BaseNanoNemotronVLProcessingInfo)
|
|
||||||
|
|
||||||
|
|
||||||
class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
|
||||||
"""ProcessingInfo extended for video processing"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_video(self):
|
|
||||||
return self.get_hf_processor().supports_video
|
|
||||||
|
|
||||||
@property
|
|
||||||
def audio_extractor(self) -> ParakeetExtractor | None:
|
|
||||||
return self.get_hf_processor().audio_extractor
|
|
||||||
|
|
||||||
def get_data_parser(self):
|
|
||||||
target_sr = None
|
|
||||||
target_channels = None
|
|
||||||
if extractor := self.audio_extractor:
|
|
||||||
target_sr = extractor.sampling_rate
|
|
||||||
target_channels = 1
|
|
||||||
|
|
||||||
return MultiModalDataParser(
|
|
||||||
video_needs_metadata=True,
|
|
||||||
target_sr=target_sr,
|
|
||||||
target_channels=target_channels,
|
|
||||||
expected_hidden_size=self._get_expected_hidden_size(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_supported_mm_limits(self):
|
|
||||||
video_limit = {"video": None} if self.supports_video else {}
|
|
||||||
audio_limit = {"audio": None} if self.audio_extractor is not None else {}
|
|
||||||
return {**super().get_supported_mm_limits(), **video_limit, **audio_limit}
|
|
||||||
|
|
||||||
def get_video_token(self) -> str | None:
|
|
||||||
return IMG_CONTEXT
|
|
||||||
|
|
||||||
def get_video_pruning_rate(self) -> float | None:
|
|
||||||
return self.ctx.get_mm_config().video_pruning_rate
|
|
||||||
|
|
||||||
def get_num_frames_with_most_features(
|
def get_num_frames_with_most_features(
|
||||||
self,
|
self,
|
||||||
seq_len: int,
|
seq_len: int,
|
||||||
@@ -306,31 +303,12 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
|||||||
max_frames_per_video = max_tubelets_per_video * T
|
max_frames_per_video = max_tubelets_per_video * T
|
||||||
return max(max_frames_per_video, 1)
|
return max(max_frames_per_video, 1)
|
||||||
|
|
||||||
def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
|
|
||||||
return self.ctx.init_processor(
|
|
||||||
NanoNemotronVLProcessor,
|
|
||||||
config=self.get_hf_config(),
|
|
||||||
tokenizer=self.get_tokenizer(),
|
|
||||||
video_token=self.get_video_token(),
|
|
||||||
video_pruning_rate=self.get_video_pruning_rate(),
|
|
||||||
max_model_len=self.ctx.model_config.max_model_len,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
class NanoNemotronVLMultiModalProcessor(
|
||||||
class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
BaseMultiModalProcessor[NanoNemotronVLProcessingInfo]
|
||||||
"""Basic image-only MultiModalProcessor for InternVL-style models."""
|
):
|
||||||
|
def _get_image_fields_config(self, hf_inputs: BatchFeature):
|
||||||
@cached_property
|
if self.info.is_dynamic_tiler:
|
||||||
def is_dynamic_tiler(self) -> bool:
|
|
||||||
return self.info.get_hf_processor().dynamic_tiler is not None
|
|
||||||
|
|
||||||
def _get_mm_fields_config(
|
|
||||||
self,
|
|
||||||
hf_inputs: BatchFeature,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
|
||||||
if self.is_dynamic_tiler:
|
|
||||||
pixel_values_flat = MultiModalFieldConfig.batched("image")
|
pixel_values_flat = MultiModalFieldConfig.batched("image")
|
||||||
else:
|
else:
|
||||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||||
@@ -346,15 +324,50 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
imgs_sizes=MultiModalFieldConfig.batched("image"),
|
imgs_sizes=MultiModalFieldConfig.batched("image"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prompt_updates(
|
def _get_video_fields_config(self, hf_inputs: BatchFeature):
|
||||||
|
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"video", video_num_patches
|
||||||
|
),
|
||||||
|
video_num_patches=MultiModalFieldConfig.batched("video"),
|
||||||
|
frames_indices=MultiModalFieldConfig.batched("video"),
|
||||||
|
frame_duration_ms=MultiModalFieldConfig.batched("video"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_audio_fields_config(self, hf_inputs: BatchFeature):
|
||||||
|
audio_num_clips = torch.as_tensor(hf_inputs["audio_num_clips"])
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"audio", audio_num_clips
|
||||||
|
),
|
||||||
|
feature_attention_mask=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"audio", audio_num_clips
|
||||||
|
),
|
||||||
|
audio_num_clips=MultiModalFieldConfig.batched("audio", keep_on_cpu=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_mm_fields_config(
|
||||||
|
self,
|
||||||
|
hf_inputs: BatchFeature,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
|
fields = self._get_image_fields_config(hf_inputs)
|
||||||
|
if self.info.supports_video:
|
||||||
|
fields |= self._get_video_fields_config(hf_inputs)
|
||||||
|
if self.info.audio_extractor:
|
||||||
|
fields |= self._get_audio_fields_config(hf_inputs)
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
def _get_prompt_repl_image(
|
||||||
self,
|
self,
|
||||||
mm_items: MultiModalDataItems,
|
mm_items: MultiModalDataItems,
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor: NanoNemotronVLProcessor,
|
||||||
out_mm_kwargs: MultiModalKwargsItems,
|
out_mm_data: BatchedTensorInputs,
|
||||||
) -> Sequence[PromptUpdate]:
|
):
|
||||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
|
||||||
|
|
||||||
out_mm_data = out_mm_kwargs.get_data()
|
|
||||||
if "image_num_patches" in out_mm_data:
|
if "image_num_patches" in out_mm_data:
|
||||||
image_num_patches = out_mm_data["image_num_patches"]
|
image_num_patches = out_mm_data["image_num_patches"]
|
||||||
assert isinstance(image_num_patches, torch.Tensor)
|
assert isinstance(image_num_patches, torch.Tensor)
|
||||||
@@ -365,7 +378,7 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
else:
|
else:
|
||||||
image_num_patches = []
|
image_num_patches = []
|
||||||
|
|
||||||
def get_replacement_custom(item_idx: int):
|
def get_image_replacement(item_idx: int):
|
||||||
images = mm_items.get_items(
|
images = mm_items.get_items(
|
||||||
"image", (ImageEmbeddingItems, ImageProcessorItems)
|
"image", (ImageEmbeddingItems, ImageProcessorItems)
|
||||||
)
|
)
|
||||||
@@ -377,10 +390,7 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
feature_size = tiler.get_cached_feature_size(image)
|
feature_size = tiler.get_cached_feature_size(image)
|
||||||
else:
|
else:
|
||||||
image_size = images.get_image_size(item_idx)
|
image_size = images.get_image_size(item_idx)
|
||||||
# Extract max_num_tiles from kwargs, default to 12
|
max_num_tiles = hf_processor.max_num_tiles
|
||||||
max_num_tiles = hf_processor_mm_kwargs.get(
|
|
||||||
"max_num_tiles", hf_processor.max_num_tiles
|
|
||||||
)
|
|
||||||
feature_size = hf_processor.get_num_image_tokens(
|
feature_size = hf_processor.get_num_image_tokens(
|
||||||
image_width=image_size.width,
|
image_width=image_size.width,
|
||||||
image_height=image_size.height,
|
image_height=image_size.height,
|
||||||
@@ -398,194 +408,18 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
|
|
||||||
return hf_processor.get_image_repl(feature_size, num_patches)
|
return hf_processor.get_image_repl(feature_size, num_patches)
|
||||||
|
|
||||||
return [
|
return PromptReplacement(
|
||||||
PromptReplacement(
|
|
||||||
modality="image",
|
modality="image",
|
||||||
target="<image>",
|
target="<image>",
|
||||||
replacement=get_replacement_custom,
|
replacement=get_image_replacement,
|
||||||
)
|
)
|
||||||
]
|
|
||||||
|
|
||||||
|
def _get_prompt_repl_video(
|
||||||
class NanoNemotronVLMultiModalProcessor(
|
|
||||||
NanoNemotronBaseVLMultiModalProcessor[NanoNemotronVLProcessingInfo]
|
|
||||||
):
|
|
||||||
"""MultiModalProcessor extended for video support"""
|
|
||||||
|
|
||||||
def _extract_audio_from_videos(
|
|
||||||
self,
|
self,
|
||||||
mm_items: MultiModalDataItems,
|
mm_items: MultiModalDataItems,
|
||||||
) -> tuple[MultiModalDataItems, list[AudioItem]]:
|
hf_processor: NanoNemotronVLProcessor,
|
||||||
"""Extract audio tracks from video bytes in *mm_items*.
|
out_mm_data: BatchedTensorInputs,
|
||||||
|
|
||||||
Returns:
|
|
||||||
The augmented *mm_items* (with audio added) and the list of
|
|
||||||
extracted audio items.
|
|
||||||
"""
|
|
||||||
videos = mm_items.get_items("video", VideoProcessorItems)
|
|
||||||
assert isinstance(videos.metadata, list)
|
|
||||||
metadata_list = videos.metadata
|
|
||||||
|
|
||||||
audio_items: list[AudioItem] = []
|
|
||||||
for metadata in metadata_list:
|
|
||||||
video_bytes = metadata.get("original_video_bytes")
|
|
||||||
if video_bytes is None or len(video_bytes) == 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot extract audio from video: original_video_bytes is "
|
|
||||||
"missing or empty. When using use_audio_in_video=True, "
|
|
||||||
"video must be loaded with keep_video_bytes=True (e.g. via "
|
|
||||||
"the chat API with a model that sets use_audio_in_video)."
|
|
||||||
)
|
|
||||||
audio_items.append(extract_audio_from_video_bytes(video_bytes))
|
|
||||||
|
|
||||||
# Create a new VideoProcessorItems with metadata that does not contain
|
|
||||||
# the large video bytes, to avoid modifying the input `mm_items`.
|
|
||||||
new_metadata_list = [
|
|
||||||
{k: v for k, v in meta.items() if k != "original_video_bytes"}
|
|
||||||
for meta in metadata_list
|
|
||||||
]
|
|
||||||
new_videos = VideoProcessorItems(data=videos.data, metadata=new_metadata_list)
|
|
||||||
|
|
||||||
audio_parsed = self.data_parser.parse_mm_data({"audio": audio_items})
|
|
||||||
|
|
||||||
# Create a new MultiModalDataItems with the new video and audio items.
|
|
||||||
new_mm_items_dict = {**mm_items, **audio_parsed, "video": new_videos}
|
|
||||||
mm_items = MultiModalDataItems(new_mm_items_dict)
|
|
||||||
|
|
||||||
return mm_items, audio_items
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
processor_inputs: ProcessorInputs,
|
|
||||||
timing_ctx: TimingContext | None = None,
|
|
||||||
) -> MultiModalInputs:
|
|
||||||
if (hf_processor_mm_kwargs := processor_inputs.hf_processor_mm_kwargs) is None:
|
|
||||||
hf_processor_mm_kwargs = {}
|
|
||||||
|
|
||||||
use_audio_in_video = bool(
|
|
||||||
hf_processor_mm_kwargs.get("use_audio_in_video", False)
|
|
||||||
)
|
|
||||||
|
|
||||||
hf_processor_mm_kwargs = {
|
|
||||||
k: v for k, v in hf_processor_mm_kwargs.items() if k != "use_audio_in_video"
|
|
||||||
}
|
|
||||||
|
|
||||||
processor_inputs.hf_processor_mm_kwargs = hf_processor_mm_kwargs
|
|
||||||
|
|
||||||
if not (
|
|
||||||
use_audio_in_video
|
|
||||||
and "video" in processor_inputs.mm_data_items
|
|
||||||
and "audio" not in processor_inputs.mm_data_items
|
|
||||||
):
|
):
|
||||||
return super().apply(
|
|
||||||
processor_inputs,
|
|
||||||
timing_ctx,
|
|
||||||
)
|
|
||||||
|
|
||||||
mm_items, audio_items = self._extract_audio_from_videos(
|
|
||||||
processor_inputs.mm_data_items
|
|
||||||
)
|
|
||||||
processor_inputs.mm_data_items = mm_items
|
|
||||||
|
|
||||||
prompt = processor_inputs.prompt
|
|
||||||
tokenizer = self.info.get_tokenizer()
|
|
||||||
if not isinstance(prompt, str):
|
|
||||||
prompt = tokenizer.decode(prompt, skip_special_tokens=False)
|
|
||||||
|
|
||||||
for _ in audio_items:
|
|
||||||
prompt = prompt.replace("<video>", "<video>" + AUDIO_CONTEXT, 1)
|
|
||||||
|
|
||||||
processor_inputs.prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
|
||||||
|
|
||||||
if processor_inputs.tokenization_kwargs is None:
|
|
||||||
processor_inputs.tokenization_kwargs = {}
|
|
||||||
|
|
||||||
# Bypass the cached path: the HF processor must receive the
|
|
||||||
# prompt (with injected <so_embedding>) and the audio data
|
|
||||||
# together so it can perform audio-token replacement natively.
|
|
||||||
(
|
|
||||||
prompt_ids,
|
|
||||||
mm_info,
|
|
||||||
is_update_applied,
|
|
||||||
) = self._apply_hf_processor(
|
|
||||||
processor_inputs,
|
|
||||||
timing_ctx=timing_ctx,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
|
|
||||||
mm_items=mm_items,
|
|
||||||
prompt_ids=prompt_ids,
|
|
||||||
mm_kwargs=mm_info.kwargs,
|
|
||||||
mm_prompt_updates=mm_info.prompt_updates,
|
|
||||||
is_update_applied=is_update_applied,
|
|
||||||
)
|
|
||||||
|
|
||||||
mm_placeholder_ranges = {
|
|
||||||
modality: [item.to_range() for item in placeholders]
|
|
||||||
for modality, placeholders in mm_placeholders.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
return MultiModalInputs(
|
|
||||||
type="multimodal",
|
|
||||||
prompt_token_ids=prompt_ids,
|
|
||||||
mm_kwargs=mm_info.kwargs,
|
|
||||||
mm_hashes=mm_info.hashes,
|
|
||||||
mm_placeholders=mm_placeholder_ranges,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_mm_fields_config(
|
|
||||||
self,
|
|
||||||
hf_inputs: BatchFeature,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
|
||||||
image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
|
|
||||||
if self.info.supports_video:
|
|
||||||
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
|
|
||||||
|
|
||||||
video_fields = dict(
|
|
||||||
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
|
|
||||||
"video", video_num_patches
|
|
||||||
),
|
|
||||||
video_num_patches=MultiModalFieldConfig.batched("video"),
|
|
||||||
frames_indices=MultiModalFieldConfig.batched("video"),
|
|
||||||
frame_duration_ms=MultiModalFieldConfig.batched("video"),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
video_fields = {}
|
|
||||||
|
|
||||||
if self.info.audio_extractor is not None:
|
|
||||||
audio_num_clips = torch.as_tensor(hf_inputs["audio_num_clips"])
|
|
||||||
audio_fields = dict(
|
|
||||||
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
|
|
||||||
"audio", audio_num_clips
|
|
||||||
),
|
|
||||||
feature_attention_mask=MultiModalFieldConfig.flat_from_sizes(
|
|
||||||
"audio", audio_num_clips
|
|
||||||
),
|
|
||||||
audio_num_clips=MultiModalFieldConfig.batched(
|
|
||||||
"audio", keep_on_cpu=True
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
audio_fields = {}
|
|
||||||
|
|
||||||
return image_fields | video_fields | audio_fields
|
|
||||||
|
|
||||||
def _get_prompt_updates(
|
|
||||||
self,
|
|
||||||
mm_items: MultiModalDataItems,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
out_mm_kwargs: MultiModalKwargsItems,
|
|
||||||
) -> Sequence[PromptUpdate]:
|
|
||||||
prompt_repl = super()._get_prompt_updates(
|
|
||||||
mm_items=mm_items,
|
|
||||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
|
||||||
out_mm_kwargs=out_mm_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
|
||||||
|
|
||||||
out_mm_data = out_mm_kwargs.get_data()
|
|
||||||
if "video_num_patches" in out_mm_data:
|
if "video_num_patches" in out_mm_data:
|
||||||
video_num_patches = out_mm_data["video_num_patches"]
|
video_num_patches = out_mm_data["video_num_patches"]
|
||||||
assert isinstance(video_num_patches, torch.Tensor)
|
assert isinstance(video_num_patches, torch.Tensor)
|
||||||
@@ -593,7 +427,7 @@ class NanoNemotronVLMultiModalProcessor(
|
|||||||
else:
|
else:
|
||||||
video_num_patches = []
|
video_num_patches = []
|
||||||
|
|
||||||
def get_video_replacement_internvl(item_idx: int):
|
def get_video_replacement(item_idx: int):
|
||||||
video, metadata = mm_items["video"][item_idx]
|
video, metadata = mm_items["video"][item_idx]
|
||||||
patch_size = hf_processor.config.patch_size
|
patch_size = hf_processor.config.patch_size
|
||||||
downsample_ratio = hf_processor.config.downsample_ratio
|
downsample_ratio = hf_processor.config.downsample_ratio
|
||||||
@@ -650,40 +484,206 @@ class NanoNemotronVLMultiModalProcessor(
|
|||||||
video_temporal_patch_size=T,
|
video_temporal_patch_size=T,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.info.supports_video:
|
return PromptReplacement(
|
||||||
prompt_repl = [
|
|
||||||
*prompt_repl,
|
|
||||||
PromptReplacement(
|
|
||||||
modality="video",
|
modality="video",
|
||||||
target="<video>",
|
target="<video>",
|
||||||
replacement=get_video_replacement_internvl,
|
replacement=get_video_replacement,
|
||||||
),
|
)
|
||||||
]
|
|
||||||
|
|
||||||
|
def _get_prompt_repl_audio(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor: NanoNemotronVLProcessor,
|
||||||
|
out_mm_data: BatchedTensorInputs,
|
||||||
|
):
|
||||||
def get_audio_replacement(item_idx: int):
|
def get_audio_replacement(item_idx: int):
|
||||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||||
return hf_processor.get_audio_repl(audios.get(item_idx))
|
return hf_processor.get_audio_repl(audios.get(item_idx))
|
||||||
|
|
||||||
if self.info.audio_extractor is not None:
|
return PromptReplacement(
|
||||||
prompt_repl = [
|
|
||||||
*prompt_repl,
|
|
||||||
PromptReplacement(
|
|
||||||
modality="audio",
|
modality="audio",
|
||||||
target=AUDIO_CONTEXT,
|
target=AUDIO_CONTEXT,
|
||||||
replacement=get_audio_replacement,
|
replacement=get_audio_replacement,
|
||||||
),
|
)
|
||||||
|
|
||||||
|
def _get_prompt_updates(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
out_mm_kwargs: MultiModalKwargsItems,
|
||||||
|
) -> Sequence[PromptUpdate]:
|
||||||
|
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||||
|
out_mm_data = out_mm_kwargs.get_data()
|
||||||
|
|
||||||
|
prompt_repls = [
|
||||||
|
self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data),
|
||||||
]
|
]
|
||||||
|
if self.info.supports_video:
|
||||||
|
prompt_repls.append(
|
||||||
|
self._get_prompt_repl_video(mm_items, hf_processor, out_mm_data)
|
||||||
|
)
|
||||||
|
if self.info.audio_extractor:
|
||||||
|
prompt_repls.append(
|
||||||
|
self._get_prompt_repl_audio(mm_items, hf_processor, out_mm_data)
|
||||||
|
)
|
||||||
|
|
||||||
return prompt_repl
|
return prompt_repls
|
||||||
|
|
||||||
|
def _extract_audio_from_videos(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
) -> tuple[MultiModalDataItems, list[AudioItem]]:
|
||||||
|
"""Extract audio tracks from video bytes in *mm_items*.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The augmented *mm_items* (with audio added) and the list of
|
||||||
|
extracted audio items.
|
||||||
|
"""
|
||||||
|
videos = mm_items.get_items("video", VideoProcessorItems)
|
||||||
|
assert isinstance(videos.metadata, list)
|
||||||
|
metadata_list = videos.metadata
|
||||||
|
|
||||||
|
audio_items: list[AudioItem] = []
|
||||||
|
for metadata in metadata_list:
|
||||||
|
video_bytes = metadata.get("original_video_bytes")
|
||||||
|
if video_bytes is None or len(video_bytes) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot extract audio from video: original_video_bytes is "
|
||||||
|
"missing or empty. When using use_audio_in_video=True, "
|
||||||
|
"video must be loaded with keep_video_bytes=True (e.g. via "
|
||||||
|
"the chat API with a model that sets use_audio_in_video)."
|
||||||
|
)
|
||||||
|
audio_items.append(extract_audio_from_video_bytes(video_bytes))
|
||||||
|
|
||||||
|
# Create a new VideoProcessorItems with metadata that does not contain
|
||||||
|
# the large video bytes, to avoid modifying the input `mm_items`.
|
||||||
|
new_metadata_list = [
|
||||||
|
{k: v for k, v in meta.items() if k != "original_video_bytes"}
|
||||||
|
for meta in metadata_list
|
||||||
|
]
|
||||||
|
new_videos = VideoProcessorItems(data=videos.data, metadata=new_metadata_list)
|
||||||
|
|
||||||
|
audio_parsed = self.data_parser.parse_mm_data({"audio": audio_items})
|
||||||
|
|
||||||
|
# Create a new MultiModalDataItems with the new video and audio items.
|
||||||
|
new_mm_items_dict = {**mm_items, **audio_parsed, "video": new_videos}
|
||||||
|
mm_items = MultiModalDataItems(new_mm_items_dict)
|
||||||
|
|
||||||
|
return mm_items, audio_items
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
inputs: ProcessorInputs,
|
||||||
|
timing_ctx: TimingContext,
|
||||||
|
) -> MultiModalInputs:
|
||||||
|
use_audio_in_video = bool(
|
||||||
|
inputs.hf_processor_mm_kwargs.get("use_audio_in_video", False)
|
||||||
|
)
|
||||||
|
inputs.hf_processor_mm_kwargs = {
|
||||||
|
k: v
|
||||||
|
for k, v in inputs.hf_processor_mm_kwargs.items()
|
||||||
|
if k != "use_audio_in_video"
|
||||||
|
}
|
||||||
|
|
||||||
|
if not (
|
||||||
|
use_audio_in_video
|
||||||
|
and "video" in inputs.mm_data_items
|
||||||
|
and "audio" not in inputs.mm_data_items
|
||||||
|
):
|
||||||
|
return super().apply(inputs, timing_ctx)
|
||||||
|
|
||||||
|
mm_items, audio_items = self._extract_audio_from_videos(inputs.mm_data_items)
|
||||||
|
inputs.mm_data_items = mm_items
|
||||||
|
|
||||||
|
prompt = inputs.prompt
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
if not isinstance(prompt, str):
|
||||||
|
prompt = tokenizer.decode(prompt, skip_special_tokens=False)
|
||||||
|
|
||||||
|
for _ in audio_items:
|
||||||
|
prompt = prompt.replace("<video>", "<video>" + AUDIO_CONTEXT, 1)
|
||||||
|
|
||||||
|
inputs.prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||||
|
|
||||||
|
if inputs.tokenization_kwargs is None:
|
||||||
|
inputs.tokenization_kwargs = {}
|
||||||
|
|
||||||
|
# Bypass the cached path: the HF processor must receive the
|
||||||
|
# prompt (with injected <so_embedding>) and the audio data
|
||||||
|
# together so it can perform audio-token replacement natively.
|
||||||
|
(
|
||||||
|
prompt_ids,
|
||||||
|
mm_info,
|
||||||
|
is_update_applied,
|
||||||
|
) = self._apply_hf_processor(inputs, timing_ctx)
|
||||||
|
|
||||||
|
with timing_ctx.record("apply_prompt_updates"):
|
||||||
|
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
|
||||||
|
mm_items=mm_items,
|
||||||
|
prompt_ids=prompt_ids,
|
||||||
|
mm_kwargs=mm_info.kwargs,
|
||||||
|
mm_prompt_updates=mm_info.prompt_updates,
|
||||||
|
is_update_applied=is_update_applied,
|
||||||
|
)
|
||||||
|
|
||||||
|
mm_placeholder_ranges = {
|
||||||
|
modality: [item.to_range() for item in placeholders]
|
||||||
|
for modality, placeholders in mm_placeholders.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return MultiModalInputs(
|
||||||
|
type="multimodal",
|
||||||
|
prompt_token_ids=prompt_ids,
|
||||||
|
mm_kwargs=mm_info.kwargs,
|
||||||
|
mm_hashes=mm_info.hashes,
|
||||||
|
mm_placeholders=mm_placeholder_ranges,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
class NanoNemotronVLDummyInputsBuilder(
|
||||||
"""Basic image-only DummyInputsBuilder for InternVL-style models."""
|
BaseDummyInputsBuilder[NanoNemotronVLProcessingInfo]
|
||||||
|
):
|
||||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||||
num_images = mm_counts.get("image", 0)
|
num_images = mm_counts.get("image", 0)
|
||||||
|
num_videos = mm_counts.get("video", 0)
|
||||||
|
num_audios = mm_counts.get("audio", 0)
|
||||||
|
|
||||||
return "<image>" * num_images
|
return (
|
||||||
|
"<image>" * num_images + "<video>" * num_videos + AUDIO_CONTEXT * num_audios
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_dummy_videos(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
num_frames: int,
|
||||||
|
num_videos: int,
|
||||||
|
overrides: VideoDummyOptions | None = None,
|
||||||
|
) -> list[VideoItem]:
|
||||||
|
videos = super()._get_dummy_videos(
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
num_frames=num_frames,
|
||||||
|
num_videos=num_videos,
|
||||||
|
overrides=overrides,
|
||||||
|
)
|
||||||
|
videos = [v.copy() for v in videos]
|
||||||
|
|
||||||
|
video_items = []
|
||||||
|
for video in videos:
|
||||||
|
video_num_frames = video.shape[0]
|
||||||
|
video_metadata = {
|
||||||
|
"fps": 2,
|
||||||
|
"duration": video_num_frames / 2.0,
|
||||||
|
"total_num_frames": video_num_frames,
|
||||||
|
"frames_indices": list(range(video_num_frames)),
|
||||||
|
"video_backend": "opencv_dynamic",
|
||||||
|
"do_sample_frames": False,
|
||||||
|
}
|
||||||
|
video_items.append((video, video_metadata))
|
||||||
|
|
||||||
|
return video_items
|
||||||
|
|
||||||
def get_dummy_mm_data(
|
def get_dummy_mm_data(
|
||||||
self,
|
self,
|
||||||
@@ -706,7 +706,7 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
|||||||
|
|
||||||
image_overrides = mm_options.get("image")
|
image_overrides = mm_options.get("image")
|
||||||
|
|
||||||
return {
|
dummy_image = {
|
||||||
"image": self._get_dummy_images(
|
"image": self._get_dummy_images(
|
||||||
width=target_width,
|
width=target_width,
|
||||||
height=target_height,
|
height=target_height,
|
||||||
@@ -715,64 +715,9 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class NanoNemotronVLDummyInputsBuilder(
|
|
||||||
NanoNemotronVLDummyInputsBuilder[NanoNemotronVLProcessingInfo]
|
|
||||||
):
|
|
||||||
"""DummyInputsBuilder extended for video support"""
|
|
||||||
|
|
||||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
|
||||||
num_videos = mm_counts.get("video", 0)
|
|
||||||
num_audios = mm_counts.get("audio", 0)
|
|
||||||
|
|
||||||
return (
|
|
||||||
super().get_dummy_text(mm_counts)
|
|
||||||
+ "<video>" * num_videos
|
|
||||||
+ AUDIO_CONTEXT * num_audios
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_dummy_videos(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
width: int,
|
|
||||||
height: int,
|
|
||||||
num_frames: int,
|
|
||||||
num_videos: int,
|
|
||||||
overrides: VideoDummyOptions | None = None,
|
|
||||||
) -> list[VideoItem]:
|
|
||||||
video = super()._get_dummy_videos(
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
num_frames=num_frames,
|
|
||||||
num_videos=1,
|
|
||||||
overrides=overrides,
|
|
||||||
)[0]
|
|
||||||
video_items = []
|
|
||||||
for _ in range(num_videos):
|
|
||||||
video_metadata = {
|
|
||||||
"total_num_frames": num_frames,
|
|
||||||
"fps": 2,
|
|
||||||
"duration": num_frames / 2.0,
|
|
||||||
"video_backend": "opencv_dynamic",
|
|
||||||
"frames_indices": [i for i in range(num_frames)],
|
|
||||||
"do_sample_frames": False,
|
|
||||||
}
|
|
||||||
video_item = (video.copy(), video_metadata)
|
|
||||||
video_items.append(video_item)
|
|
||||||
|
|
||||||
return video_items
|
|
||||||
|
|
||||||
def get_dummy_mm_data(
|
|
||||||
self,
|
|
||||||
seq_len: int,
|
|
||||||
mm_counts: Mapping[str, int],
|
|
||||||
mm_options: Mapping[str, BaseDummyOptions],
|
|
||||||
) -> MultiModalDataDict:
|
|
||||||
dummy_image = super().get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
|
||||||
if self.info.supports_video:
|
if self.info.supports_video:
|
||||||
config = self.info.get_hf_config()
|
config = self.info.get_hf_config()
|
||||||
image_size: int = config.force_image_size
|
image_size: int = config.force_image_size
|
||||||
processor = self.info.get_hf_processor()
|
|
||||||
|
|
||||||
# When video_target_num_patches is set the per-frame pixel
|
# When video_target_num_patches is set the per-frame pixel
|
||||||
# resolution can exceed image_size. Use the actual target
|
# resolution can exceed image_size. Use the actual target
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
# Copyright (c) 2024 NVIDIA
|
# Copyright (c) 2024 NVIDIA
|
||||||
# Licensed under Apache 2.0 License [see LICENSE for details]
|
# Licensed under Apache 2.0 License [see LICENSE for details]
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -16,7 +16,10 @@ from transformers import PretrainedConfig
|
|||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
|
from vllm.multimodal.inputs import (
|
||||||
|
BatchedTensorInputs,
|
||||||
|
MultiModalDataDict,
|
||||||
|
)
|
||||||
from vllm.multimodal.parse import (
|
from vllm.multimodal.parse import (
|
||||||
ImageEmbeddingItems,
|
ImageEmbeddingItems,
|
||||||
ImageProcessorItems,
|
ImageProcessorItems,
|
||||||
@@ -24,7 +27,6 @@ from vllm.multimodal.parse import (
|
|||||||
)
|
)
|
||||||
from vllm.multimodal.processing import (
|
from vllm.multimodal.processing import (
|
||||||
PromptReplacement,
|
PromptReplacement,
|
||||||
PromptUpdate,
|
|
||||||
PromptUpdateDetails,
|
PromptUpdateDetails,
|
||||||
)
|
)
|
||||||
from vllm.transformers_utils.processors.internvl import InternVLImageProcessor
|
from vllm.transformers_utils.processors.internvl import InternVLImageProcessor
|
||||||
@@ -100,15 +102,12 @@ class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo])
|
|||||||
|
|
||||||
|
|
||||||
class NVLMMultiModalProcessor(BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):
|
class NVLMMultiModalProcessor(BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):
|
||||||
def _get_prompt_updates(
|
def _get_prompt_repl_image(
|
||||||
self,
|
self,
|
||||||
mm_items: MultiModalDataItems,
|
mm_items: MultiModalDataItems,
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor: NVLMProcessor,
|
||||||
out_mm_kwargs: MultiModalKwargsItems,
|
out_mm_data: BatchedTensorInputs,
|
||||||
) -> Sequence[PromptUpdate]:
|
):
|
||||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
|
||||||
|
|
||||||
out_mm_data = out_mm_kwargs.get_data()
|
|
||||||
if "image_num_patches" in out_mm_data:
|
if "image_num_patches" in out_mm_data:
|
||||||
image_num_patches = out_mm_data["image_num_patches"]
|
image_num_patches = out_mm_data["image_num_patches"]
|
||||||
assert isinstance(image_num_patches, torch.Tensor)
|
assert isinstance(image_num_patches, torch.Tensor)
|
||||||
@@ -146,13 +145,11 @@ class NVLMMultiModalProcessor(BaseInternVLMultiModalProcessor[NVLMProcessingInfo
|
|||||||
)
|
)
|
||||||
|
|
||||||
# See note in dummy data regarding why we have the extra newline
|
# See note in dummy data regarding why we have the extra newline
|
||||||
return [
|
return PromptReplacement(
|
||||||
PromptReplacement(
|
|
||||||
modality="image",
|
modality="image",
|
||||||
target="<image>\n",
|
target="<image>\n",
|
||||||
replacement=get_replacement_nvlm,
|
replacement=get_replacement_nvlm,
|
||||||
)
|
)
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
|||||||
@@ -931,20 +931,30 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
|
|||||||
height: int,
|
height: int,
|
||||||
num_frames: int,
|
num_frames: int,
|
||||||
num_videos: int,
|
num_videos: int,
|
||||||
|
overrides: VideoDummyOptions | None = None,
|
||||||
) -> list[VideoItem]:
|
) -> list[VideoItem]:
|
||||||
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
|
videos = super()._get_dummy_videos(
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
num_frames=num_frames,
|
||||||
|
num_videos=num_videos,
|
||||||
|
overrides=overrides,
|
||||||
|
)
|
||||||
|
videos = [v.copy() for v in videos]
|
||||||
|
|
||||||
video_items = []
|
video_items = []
|
||||||
for i in range(num_videos):
|
for video in videos:
|
||||||
|
video_num_frames = video.shape[0]
|
||||||
video_metadata = {
|
video_metadata = {
|
||||||
"fps": 2.0,
|
"fps": 2.0,
|
||||||
"duration": num_frames / 2.0,
|
"duration": video_num_frames / 2.0,
|
||||||
"total_num_frames": num_frames,
|
"total_num_frames": video_num_frames,
|
||||||
"frames_indices": [i for i in range(num_frames)],
|
"frames_indices": list(range(video_num_frames)),
|
||||||
"video_backend": "opencv",
|
"video_backend": "opencv",
|
||||||
"do_sample_frames": False,
|
"do_sample_frames": False,
|
||||||
}
|
}
|
||||||
video_item = (video.copy(), video_metadata)
|
video_items.append((video, video_metadata))
|
||||||
video_items.append(video_item)
|
|
||||||
return video_items
|
return video_items
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,12 @@
|
|||||||
# Copyright (c) 2025 Skywork
|
# Copyright (c) 2025 Skywork
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping
|
||||||
from typing import Annotated, Literal, TypeAlias
|
from typing import Annotated, Literal, TypeAlias
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import BatchFeature, PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
@@ -24,24 +24,8 @@ from vllm.model_executor.models.intern_vit import (
|
|||||||
InternVisionPatchModel,
|
InternVisionPatchModel,
|
||||||
)
|
)
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (
|
from vllm.multimodal.inputs import MultiModalDataDict
|
||||||
MultiModalDataDict,
|
from vllm.multimodal.processing import BaseDummyInputsBuilder
|
||||||
MultiModalFieldConfig,
|
|
||||||
MultiModalKwargsItems,
|
|
||||||
)
|
|
||||||
from vllm.multimodal.parse import (
|
|
||||||
ImageEmbeddingItems,
|
|
||||||
ImageProcessorItems,
|
|
||||||
ImageSize,
|
|
||||||
MultiModalDataItems,
|
|
||||||
)
|
|
||||||
from vllm.multimodal.processing import (
|
|
||||||
BaseDummyInputsBuilder,
|
|
||||||
BaseMultiModalProcessor,
|
|
||||||
BaseProcessingInfo,
|
|
||||||
PromptReplacement,
|
|
||||||
PromptUpdate,
|
|
||||||
)
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.processors.internvl import (
|
from vllm.transformers_utils.processors.internvl import (
|
||||||
InternVLImageProcessor,
|
InternVLImageProcessor,
|
||||||
@@ -50,6 +34,11 @@ from vllm.transformers_utils.processors.internvl import (
|
|||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
|
from .internvl import (
|
||||||
|
BaseInternVLDummyInputsBuilder,
|
||||||
|
BaseInternVLMultiModalProcessor,
|
||||||
|
BaseInternVLProcessingInfo,
|
||||||
|
)
|
||||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
@@ -98,7 +87,7 @@ SkyworkR1VImageInputs: TypeAlias = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SkyworkR1VProcessingInfo(BaseProcessingInfo):
|
class SkyworkR1VProcessingInfo(BaseInternVLProcessingInfo):
|
||||||
def get_image_processor(self, **kwargs):
|
def get_image_processor(self, **kwargs):
|
||||||
config = self.get_hf_config()
|
config = self.get_hf_config()
|
||||||
vision_config = config.vision_config
|
vision_config = config.vision_config
|
||||||
@@ -128,46 +117,6 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo):
|
|||||||
image_seq_length=image_seq_length,
|
image_seq_length=image_seq_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
|
||||||
return {"image": None}
|
|
||||||
|
|
||||||
def get_num_image_tokens(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
image_width: int,
|
|
||||||
image_height: int,
|
|
||||||
processor: InternVLProcessor,
|
|
||||||
) -> int:
|
|
||||||
return processor.get_num_image_tokens(
|
|
||||||
image_width=image_width,
|
|
||||||
image_height=image_height,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_image_size_with_most_features(self) -> ImageSize:
|
|
||||||
processor = self.get_hf_processor()
|
|
||||||
image_processor = processor.image_processor
|
|
||||||
|
|
||||||
base_size = image_processor.image_size
|
|
||||||
target_ratios = processor.resolve_target_ratios()
|
|
||||||
|
|
||||||
largest_feature_size, largest_feature_pinpoint = 0, None
|
|
||||||
for wr, hr in target_ratios:
|
|
||||||
width, height = base_size * wr, base_size * hr
|
|
||||||
|
|
||||||
feat_size = self.get_num_image_tokens(
|
|
||||||
image_width=width,
|
|
||||||
image_height=height,
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
if feat_size > largest_feature_size:
|
|
||||||
largest_feature_size = feat_size
|
|
||||||
largest_feature_pinpoint = ImageSize(width=width, height=height)
|
|
||||||
|
|
||||||
if largest_feature_size == 0 or largest_feature_pinpoint is None:
|
|
||||||
raise ValueError("Cannot have a largest feature size of 0!")
|
|
||||||
|
|
||||||
return largest_feature_pinpoint
|
|
||||||
|
|
||||||
|
|
||||||
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
|
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
|
||||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||||
@@ -196,102 +145,10 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingIn
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[SkyworkR1VProcessingInfo]):
|
|
||||||
def _call_hf_processor(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
mm_data: Mapping[str, object],
|
|
||||||
mm_kwargs: Mapping[str, object],
|
|
||||||
tok_kwargs: Mapping[str, object],
|
|
||||||
) -> BatchFeature:
|
|
||||||
processed_outputs = super()._call_hf_processor(
|
|
||||||
prompt=prompt,
|
|
||||||
mm_data=mm_data,
|
|
||||||
mm_kwargs=mm_kwargs,
|
|
||||||
tok_kwargs=tok_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
|
||||||
image_token_id = hf_processor.ctx_image_token_id
|
|
||||||
|
|
||||||
# Since there may be extra tokens in the feature placeholders,
|
|
||||||
# we need to pass the image token ID to the model to select the
|
|
||||||
# tokens to merge from the vision encoder outputs
|
|
||||||
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
|
|
||||||
|
|
||||||
return processed_outputs
|
|
||||||
|
|
||||||
def _get_mm_fields_config(
|
|
||||||
self,
|
|
||||||
hf_inputs: BatchFeature,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
|
||||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
|
||||||
num_images = len(image_num_patches)
|
|
||||||
|
|
||||||
return dict(
|
|
||||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
|
||||||
"image", image_num_patches
|
|
||||||
),
|
|
||||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
|
||||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
|
||||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_prompt_updates(
|
|
||||||
self,
|
|
||||||
mm_items: MultiModalDataItems,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
out_mm_kwargs: MultiModalKwargsItems,
|
|
||||||
) -> Sequence[PromptUpdate]:
|
|
||||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
|
||||||
|
|
||||||
out_mm_data = out_mm_kwargs.get_data()
|
|
||||||
if "image_num_patches" in out_mm_data:
|
|
||||||
image_num_patches = out_mm_data["image_num_patches"]
|
|
||||||
assert isinstance(image_num_patches, torch.Tensor)
|
|
||||||
image_num_patches = image_num_patches.tolist()
|
|
||||||
elif "image_embeds" in out_mm_data:
|
|
||||||
# TODO: Use image size information in dictionary embedding inputs
|
|
||||||
# to compute num_patches (similar to Qwen2-VL)
|
|
||||||
image_num_patches = [None] * len(out_mm_data["image_embeds"])
|
|
||||||
else:
|
|
||||||
image_num_patches = []
|
|
||||||
|
|
||||||
def get_replacement_skyworkr1v(item_idx: int):
|
|
||||||
images = mm_items.get_items(
|
|
||||||
"image", (ImageEmbeddingItems, ImageProcessorItems)
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(images, ImageEmbeddingItems):
|
|
||||||
feature_size = images.get_feature_size(item_idx)
|
|
||||||
else:
|
|
||||||
image_size = images.get_image_size(item_idx)
|
|
||||||
feature_size = self.info.get_num_image_tokens(
|
|
||||||
image_width=image_size.width,
|
|
||||||
image_height=image_size.height,
|
|
||||||
processor=hf_processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
num_patches = image_num_patches[item_idx]
|
|
||||||
if num_patches is not None:
|
|
||||||
assert isinstance(num_patches, int)
|
|
||||||
|
|
||||||
return hf_processor.get_image_repl(num_patches, num_features=feature_size)
|
|
||||||
|
|
||||||
return [
|
|
||||||
PromptReplacement(
|
|
||||||
modality="image",
|
|
||||||
target="<image>",
|
|
||||||
replacement=get_replacement_skyworkr1v,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
SkyworkR1VMultiModalProcessor,
|
BaseInternVLMultiModalProcessor,
|
||||||
info=SkyworkR1VProcessingInfo,
|
info=SkyworkR1VProcessingInfo,
|
||||||
dummy_inputs=SkyworkR1VDummyInputsBuilder,
|
dummy_inputs=BaseInternVLDummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user