[Model][VLM] Add multi-video support for LLaVA-Onevision (#8905)

Co-authored-by: litianjian <litianjian@bytedance.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
litianjian
2024-10-29 02:04:10 +08:00
committed by GitHub
parent 8b0e4f2ad7
commit 5f8d8075f9
5 changed files with 123 additions and 162 deletions

View File

@@ -43,19 +43,17 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
_MAX_NUM_VIDEOS = 1
class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, num_frames, num_channels, height, width)`
Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
Note that `num_frames` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
Note that it only supports one video input for one batch.
Note that `num_videos` may be different for each batch, and 'num_frames'
may be different for each video, in which case the data is passed as a
list instead of a batched tensor.
"""
@@ -213,11 +211,7 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
# TODO: support multiple videos
num_videos = mm_counts["video"]
if num_videos > _MAX_NUM_VIDEOS:
raise NotImplementedError(
f"Only {_MAX_NUM_VIDEOS} videos are supported")
# TODO: support configuring the number of frames
num_frames = _MAX_FRAMES_PER_VIDEO
@@ -232,7 +226,9 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
image_feature_size_override=video_feature_size,
)
mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames)
mm_data = dummy_video_for_clip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
@@ -243,7 +239,9 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
image_feature_size_override=video_feature_size,
)
mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames)
mm_data = dummy_video_for_siglip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
@@ -315,7 +313,6 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
if isinstance(video_data, np.ndarray):
# Supports both CLIP and Siglip
@@ -336,10 +333,27 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
multi_modal_data=multi_modal_data)
elif is_list_of(video_data, np.ndarray):
raise NotImplementedError(
"Processing multiple videos is not supported")
video_feature_size = []
for video in video_data:
num_frames = video.shape[0]
video_feature_size.append(
get_llava_onevision_video_tokens(ctx, num_frames))
msg = f"Unsupported vision config: {type(vision_config)}"
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
else:
raise TypeError(f"Invalid video type: {type(video_data)}")
msg = f"Unsupported video type: {type(video_data)}"
raise NotImplementedError(msg)
@@ -723,6 +737,22 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings)
]
def _add_image_newline(
self,
video_features: torch.Tensor,
videos: int = 1,
frames: int = 1,
strategy: str = "one_token",
) -> torch.Tensor:
if strategy == "one_token":
video_features = video_features.reshape(
videos, frames * video_features.shape[1], -1)
image_newline = self.image_newline[None, None, :].repeat(
videos, 1, 1).to(video_features.device)
video_features = torch.cat((video_features, image_newline), dim=1)
return video_features
raise ValueError(f"Unexpected video newline strategy: {strategy}")
def _video_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
@@ -731,9 +761,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
b, num_videos, frames, c, h, w = pixel_values.shape
assert (num_videos == _MAX_NUM_VIDEOS)
pixel_values = pixel_values.reshape(b * num_videos * frames, c, h, w)
video_features = vision_tower(pixel_values)
video_features = self._select_image_features(
video_features,
@@ -741,13 +768,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
)
video_features = self.multi_modal_projector(video_features)
video_features = self.apply_pooling(video_features)
video_features = video_features.reshape(
b, frames * video_features.shape[1], -1)
image_newline = self.image_newline[None, None, :].repeat(b, 1, 1).to(
video_features.device)
video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1)
return video_features
def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
@@ -755,10 +775,28 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
video_pixels = inputs["data"]
# TODO: support multiple videos per input
if isinstance(video_pixels, torch.Tensor):
b, num_videos, frames, c, h, w = video_pixels.shape
pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, video_pixels)
self.vision_tower, pixel_values)
stacked_embeddings = self._add_image_newline(stacked_embeddings,
videos=b * num_videos,
frames=frames,
strategy="one_token")
return stacked_embeddings
elif is_list_of(video_pixels, torch.Tensor):
stacked_embeddings = []
for video_pixel in video_pixels:
num_videos, frames, c, h, w = video_pixel.shape
pixel_values = video_pixel.view(num_videos * frames, c, h, w)
embeddings = self._video_pixels_to_features(
self.vision_tower, pixel_values)
embeddings = self._add_image_newline(embeddings,
videos=num_videos,
frames=frames,
strategy="one_token")
stacked_embeddings.append(embeddings)
return stacked_embeddings
else:
raise ValueError(