[Model][VLM] Add multi-video support for LLaVA-Onevision (#8905)
Co-authored-by: litianjian <litianjian@bytedance.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -43,19 +43,17 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
 
 # For profile run
 _MAX_FRAMES_PER_VIDEO = 16
-_MAX_NUM_VIDEOS = 1
 
 
 class LlavaOnevisionVideoPixelInputs(TypedDict):
     type: Literal["pixel_values_videos"]
     data: Union[torch.Tensor, List[torch.Tensor]]
     """
-    Shape: `(batch_size, num_frames, num_channels, height, width)`
+    Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
 
-    Note that `num_frames` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
-
-    Note that it only supports one video input for one batch.
+    Note that `num_videos` may be different for each batch, and 'num_frames'
+    may be different for each video, in which case the data is passed as a
+    list instead of a batched tensor.
     """
 
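As a reviewer aid, here is a minimal sketch of the two input layouts the updated docstring describes. The tensor sizes below are made up for illustration and are not taken from this diff:

```python
import torch

# Homogeneous batch: same num_videos and num_frames everywhere, so the data
# can be one stacked tensor of shape
# (batch_size, num_videos, num_frames, num_channels, height, width).
batched = torch.zeros(2, 1, 16, 3, 384, 384)

# Heterogeneous batch: per-request tensors whose num_videos / num_frames
# differ are passed as a list instead of a batched tensor.
ragged = [
    torch.zeros(1, 16, 3, 384, 384),  # request 0: one 16-frame video
    torch.zeros(2, 8, 3, 384, 384),   # request 1: two 8-frame videos
]
```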
@@ -213,11 +211,7 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
     hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
     vision_config = hf_config.vision_config
 
-    # TODO: support multiple videos
     num_videos = mm_counts["video"]
-    if num_videos > _MAX_NUM_VIDEOS:
-        raise NotImplementedError(
-            f"Only {_MAX_NUM_VIDEOS} videos are supported")
 
     # TODO: support configuring the number of frames
     num_frames = _MAX_FRAMES_PER_VIDEO
@@ -232,7 +226,9 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
             image_feature_size_override=video_feature_size,
         )
 
-        mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames)
+        mm_data = dummy_video_for_clip(vision_config,
+                                       num_frames=num_frames,
+                                       num_videos=num_videos)
         return seq_data, mm_data
     elif isinstance(vision_config, SiglipVisionConfig):
         seq_data = dummy_seq_data_for_siglip(
@@ -243,7 +239,9 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
             image_feature_size_override=video_feature_size,
         )
 
-        mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames)
+        mm_data = dummy_video_for_siglip(vision_config,
+                                         num_frames=num_frames,
+                                         num_videos=num_videos)
         return seq_data, mm_data
 
     msg = f"Unsupported vision config: {type(vision_config)}"
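For context, the new `num_videos` argument lets the profile run build one dummy clip per video slot instead of assuming a single video. A simplified sketch of that behavior with a hypothetical helper (this is not the actual `dummy_video_for_clip`/`dummy_video_for_siglip` code, and the frame layout is illustrative only):

```python
import numpy as np

def make_dummy_videos(width: int, height: int, num_frames: int,
                      num_videos: int) -> dict:
    # Hypothetical stand-in: build one all-zero clip and replicate it once
    # per video slot so profiling covers the multi-video case.
    video = np.zeros((num_frames, height, width, 3), dtype=np.uint8)
    return {"video": video if num_videos == 1 else [video] * num_videos}
```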
@@ -315,7 +313,6 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
 
     model_config = ctx.model_config
     hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
-    vision_config = hf_config.vision_config
 
     if isinstance(video_data, np.ndarray):
         # Supports both CLIP and Siglip
@@ -336,10 +333,27 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
                             multi_modal_data=multi_modal_data)
 
     elif is_list_of(video_data, np.ndarray):
-        raise NotImplementedError(
-            "Processing multiple videos is not supported")
+        video_feature_size = []
+        for video in video_data:
+            num_frames = video.shape[0]
+            video_feature_size.append(
+                get_llava_onevision_video_tokens(ctx, num_frames))
+
+        tokenizer = cached_get_tokenizer(model_config.tokenizer)
+        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+            tokenizer,
+            inputs.get("prompt"),
+            inputs["prompt_token_ids"],
+            placeholder_token_id=hf_config.video_token_index,
+            repeat_count=video_feature_size,
+        )
+        return token_inputs(prompt_token_ids=new_token_ids,
+                            prompt=new_prompt,
+                            multi_modal_data=multi_modal_data)
+    else:
+        raise TypeError(f"Invalid video type: {type(video_data)}")
 
-    msg = f"Unsupported vision config: {type(vision_config)}"
+    msg = f"Unsupported video type: {type(video_data)}"
     raise NotImplementedError(msg)
 
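The key change above is that `repeat_count` is now a list with one entry per video, so each placeholder occurrence in the prompt expands to that video's feature size. A self-contained sketch of the expansion semantics (not the real `repeat_and_pad_placeholder_tokens` implementation; the token id is arbitrary):

```python
from typing import List

def expand_placeholders(token_ids: List[int], placeholder: int,
                        repeat_counts: List[int]) -> List[int]:
    # The i-th occurrence of the placeholder token is expanded to
    # repeat_counts[i] copies; all other tokens pass through unchanged.
    out: List[int] = []
    seen = 0
    for tok in token_ids:
        if tok == placeholder:
            out.extend([placeholder] * repeat_counts[seen])
            seen += 1
        else:
            out.append(tok)
    return out

# Two <video> placeholders (id 32000) expanded to 3 and 2 feature slots.
assert expand_placeholders([1, 32000, 5, 32000], 32000, [3, 2]) == \
    [1, 32000, 32000, 32000, 5, 32000, 32000]
```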
@@ -723,6 +737,22 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
+    def _add_image_newline(
+        self,
+        video_features: torch.Tensor,
+        videos: int = 1,
+        frames: int = 1,
+        strategy: str = "one_token",
+    ) -> torch.Tensor:
+        if strategy == "one_token":
+            video_features = video_features.reshape(
+                videos, frames * video_features.shape[1], -1)
+            image_newline = self.image_newline[None, None, :].repeat(
+                videos, 1, 1).to(video_features.device)
+            video_features = torch.cat((video_features, image_newline), dim=1)
+            return video_features
+        raise ValueError(f"Unexpected video newline strategy: {strategy}")
+
     def _video_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
@@ -731,9 +761,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        b, num_videos, frames, c, h, w = pixel_values.shape
-        assert (num_videos == _MAX_NUM_VIDEOS)
-        pixel_values = pixel_values.reshape(b * num_videos * frames, c, h, w)
         video_features = vision_tower(pixel_values)
         video_features = self._select_image_features(
             video_features,
@@ -741,13 +768,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         )
         video_features = self.multi_modal_projector(video_features)
         video_features = self.apply_pooling(video_features)
-        video_features = video_features.reshape(
-            b, frames * video_features.shape[1], -1)
-        image_newline = self.image_newline[None, None, :].repeat(b, 1, 1).to(
-            video_features.device)
-        video_features = torch.cat((video_features, image_newline), dim=1)
-        video_features = video_features.flatten(0, 1)
-
         return video_features
 
     def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
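The newline logic removed here is the same math that now lives in `_add_image_newline`: flatten the frames into one token sequence per video, then append a single learned newline embedding. A standalone restatement with arbitrary dimensions (none of these sizes come from the diff):

```python
import torch

# Per-frame features of shape (videos * frames, tokens_per_frame, hidden)
# become (videos, frames * tokens_per_frame + 1, hidden) after appending one
# "newline" embedding per video.
videos, frames, tokens_per_frame, hidden = 2, 8, 196, 896
video_features = torch.randn(videos * frames, tokens_per_frame, hidden)
image_newline = torch.randn(hidden)  # stands in for self.image_newline

video_features = video_features.reshape(
    videos, frames * video_features.shape[1], -1)
newline = image_newline[None, None, :].repeat(videos, 1, 1)
video_features = torch.cat((video_features, newline), dim=1)
assert video_features.shape == (videos, frames * tokens_per_frame + 1, hidden)
```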
@@ -755,10 +775,28 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         video_pixels = inputs["data"]
 
-        # TODO: support multiple videos per input
-        stacked_embeddings = self._video_pixels_to_features(
-            self.vision_tower, video_pixels)
-        return stacked_embeddings
+        if isinstance(video_pixels, torch.Tensor):
+            b, num_videos, frames, c, h, w = video_pixels.shape
+            pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
+            stacked_embeddings = self._video_pixels_to_features(
+                self.vision_tower, pixel_values)
+            stacked_embeddings = self._add_image_newline(stacked_embeddings,
+                                                         videos=b * num_videos,
+                                                         frames=frames,
+                                                         strategy="one_token")
+            return stacked_embeddings
+        elif is_list_of(video_pixels, torch.Tensor):
+            stacked_embeddings = []
+            for video_pixel in video_pixels:
+                num_videos, frames, c, h, w = video_pixel.shape
+                pixel_values = video_pixel.view(num_videos * frames, c, h, w)
+                embeddings = self._video_pixels_to_features(
+                    self.vision_tower, pixel_values)
+                embeddings = self._add_image_newline(embeddings,
+                                                     videos=num_videos,
+                                                     frames=frames,
+                                                     strategy="one_token")
+                stacked_embeddings.append(embeddings)
+            return stacked_embeddings
+        else:
+            raise ValueError(
+                f"Unsupported type of video input {type(video_pixels)}")
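To exercise the new multi-video path end to end, something along these lines should work with vLLM's offline API. The model id, chat template, and `limit_mm_per_prompt` setting are best-effort assumptions rather than part of this diff; adjust them for your setup:

```python
import numpy as np
from vllm import LLM, SamplingParams

# Allow two videos per prompt so the list-of-videos branch is exercised.
llm = LLM(
    model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    limit_mm_per_prompt={"video": 2},
)

# Two dummy 8-frame clips; real inputs would come from a video decoder.
videos = [np.zeros((8, 384, 384, 3), dtype=np.uint8) for _ in range(2)]

outputs = llm.generate(
    {
        "prompt": "<|im_start|>user <video><video>\nCompare the two videos."
                  "<|im_end|><|im_start|>assistant\n",
        "multi_modal_data": {"video": videos},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```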