# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import random
from dataclasses import dataclass

import pytest
import torch

from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
from vllm.multimodal.inputs import (
    MultiModalFeatureSpec,
    MultiModalFieldElem,
    MultiModalKwargsItem,
    PlaceholderRange,
)
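
# NOTE (test overview, inferred from the code below): with EVS-style video
# token pruning, each synthetic frame contributes
#   [text...] [VISION_START] [VIDEO] * h * w [VISION_END]
# and the test checks that _recompute_mrope_positions rebuilds, for the pruned
# sequence, the same M-RoPE positions that _get_mrope_input_positions assigns
# to the retained tokens of the unpruned sequence.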


@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
    # _get_mrope_input_positions returns CPU tensors (via torch.from_numpy).
    # Ensure the default device is CPU so the rest of the test tensors match.
    original = torch.get_default_device()
    torch.set_default_device("cpu")
    yield
    torch.set_default_device(original)


IMAGE_TOKEN_ID = 999
VIDEO_TOKEN_ID = 888
VISION_START_TOKEN_ID = 777
VISION_END_TOKEN_ID = 778


@dataclass
class DummyVisionConfig:
    spatial_merge_size: int = 1


@dataclass
class DummyConfig:
    image_token_id: int = IMAGE_TOKEN_ID
    video_token_id: int = VIDEO_TOKEN_ID
    vision_start_token_id: int = VISION_START_TOKEN_ID
    vision_end_token_id: int = VISION_END_TOKEN_ID
    vision_config: DummyVisionConfig = dataclasses.field(
        default_factory=DummyVisionConfig
    )


def make_video_embedding(
    t, h, w, interleave_text_tokens: tuple[int, int], video_pruning_rate: float = 0.0
):
    """
    Build the token sequence for a video of the given grid size, together with
    a randomly pruned variant and the corresponding retention mask.

    Args:
        t: Number of frames.
        h: Number of rows.
        w: Number of columns.
        interleave_text_tokens: Tuple of minimum and maximum number of text
            tokens to interleave with the video.
        video_pruning_rate: Probability with which each video token is pruned.

    Returns:
        Tuple of (unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask).
    """
    unpruned_tokens_sequence = []
    population = list(range(1, 100))

    for _ in range(t):
        num_prefix_tokens = random.randint(
            interleave_text_tokens[0], interleave_text_tokens[1]
        )

        prefix_tokens = random.choices(population, k=num_prefix_tokens)
        vision_tokens = (
            [VISION_START_TOKEN_ID] + [VIDEO_TOKEN_ID] * h * w + [VISION_END_TOKEN_ID]
        )

        unpruned_tokens_sequence.extend(prefix_tokens)
        unpruned_tokens_sequence.extend(vision_tokens)

    unpruned_tokens_sequence = torch.tensor(unpruned_tokens_sequence, dtype=torch.long)
    video_token_mask = unpruned_tokens_sequence == VIDEO_TOKEN_ID
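
    # Prune each video token independently with probability video_pruning_rate.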
    pruning_mask = torch.bernoulli(video_token_mask.float() * video_pruning_rate).bool()  # type: ignore[attr-defined]
    # Sanity check that we don't prune what should not be pruned.
    assert not pruning_mask[~video_token_mask].any()

    retention_mask = ~pruning_mask
    pruned_tokens_sequence = unpruned_tokens_sequence[retention_mask]
    return unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask


@pytest.mark.parametrize("spatial_merge_size", [1, 2])
@pytest.mark.parametrize("grid_thw", [[3, 8, 7], [128, 10, 12]])
@pytest.mark.parametrize("num_prefix_tokens", [1, 11])
@pytest.mark.parametrize("num_suffix_tokens", [0, 7])
@pytest.mark.parametrize("video_pruning_rate", [0, 0.25, 0.75])
@pytest.mark.parametrize("interleave_text_tokens", [(0, 0), (1, 4)])
def test_match_qwen3vl_mrope_evs_on(
    spatial_merge_size: int,
    num_prefix_tokens: int,
    grid_thw: tuple[int, int, int],
    num_suffix_tokens: int,
    video_pruning_rate: float,
    interleave_text_tokens: tuple[int, int],
):
    hf_config = DummyConfig()
    hf_config.vision_config.spatial_merge_size = spatial_merge_size

    t, h, w = grid_thw
    population = list(range(1, 100))
    prefix_tokens = random.choices(population, k=num_prefix_tokens)
    suffix_tokens = random.choices(population, k=num_suffix_tokens)

    video_tokens, video_tokens_pruned, retention_mask = make_video_embedding(
        t,
        h // spatial_merge_size,
        w // spatial_merge_size,
        interleave_text_tokens=interleave_text_tokens,
        video_pruning_rate=video_pruning_rate,
    )
    assert len(video_tokens) == len(retention_mask)

    input_tokens = prefix_tokens + video_tokens.tolist() + suffix_tokens
    input_tokens_pruned = prefix_tokens + video_tokens_pruned.tolist() + suffix_tokens
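
    # Text prefix/suffix tokens are always retained; only video tokens are
    # subject to pruning.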
    whole_sequence_retention_mask = torch.cat(
        [
            torch.ones(len(prefix_tokens), dtype=torch.bool),
            retention_mask,
            torch.ones(len(suffix_tokens), dtype=torch.bool),
        ],
        dim=0,
    )

    # Build the ground-truth mrope for the unpruned input.
    mm_feature = MultiModalFeatureSpec(
        data=MultiModalKwargsItem(
            {
                "video_grid_thw": MultiModalFieldElem(
                    data=torch.tensor(grid_thw),
                    field=None,  # HACK.
                ),
            }
        ),
        modality="video",
        identifier="DUMMY",
        mm_position=PlaceholderRange(offset=0, length=len(input_tokens)),
    )
    expected_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions(
        input_tokens=input_tokens,
        mm_features=[mm_feature],
        config=hf_config,
    )

    # Compute mrope for a video-only media item (unpruned).
    mm_feature = MultiModalFeatureSpec(
        data=MultiModalKwargsItem(
            {
                "video_grid_thw": MultiModalFieldElem(
                    data=torch.tensor(grid_thw),
                    field=None,  # HACK.
                ),
            }
        ),
        modality="video",
        identifier="DUMMY",
        mm_position=PlaceholderRange(offset=0, length=video_tokens.numel()),
    )
    video_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions(
        input_tokens=video_tokens.tolist(),
        mm_features=[mm_feature],
        config=hf_config,
    )
    video_mrope = video_mrope.permute(1, 0)  # [N, 3]
    hidden_size = 16
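
    # Each pruned-sequence embedding row gets 5 extra channels: its 3 M-RoPE
    # coordinates plus (is_vision_start, is_video_embed) flags. This is the
    # side-channel layout this test assumes _recompute_mrope_positions reads.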
    is_video_embed = torch.isin(
        video_tokens_pruned, torch.tensor([VIDEO_TOKEN_ID], dtype=torch.long)
    )

    expanded_positions = torch.full(
        (len(video_tokens_pruned), 5),
        fill_value=-100,
        device=video_mrope.device,
        dtype=torch.long,
    )
    expanded_positions[is_video_embed, :3] = video_mrope[retention_mask][is_video_embed]
    expanded_positions[~is_video_embed, :3] = video_mrope[retention_mask][
        ~is_video_embed
    ]

    is_vision_start = video_tokens_pruned == VISION_START_TOKEN_ID
    expanded_positions[..., 3] = is_vision_start
    expanded_positions[..., 4] = is_video_embed

    # Check that all positions were filled, since we initialized them as negative.
    assert (expanded_positions >= 0).all()
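
    # The embedding payload itself is arbitrary (torch.empty); only the
    # position/flag columns appended on dim=1 matter for this test.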
    video_embeddings = torch.empty(
        (len(video_tokens_pruned), hidden_size), device=video_mrope.device
    )

    video_embeddings = torch.cat(
        [
            video_embeddings,
            expanded_positions.float(),
        ],
        dim=1,
    )
    multimodal_embeddings = [video_embeddings]

    expected_mrope_masked = expected_mrope[:, whole_sequence_retention_mask]

    # Initialize computed_mrope with the correct positions for the prefix
    # tokens only; the remainder is deliberately left uninitialized.
    computed_mrope = torch.empty((3, len(input_tokens_pruned)), dtype=torch.long)
    computed_mrope[:, 0 : len(prefix_tokens)] = expected_mrope[
        :, 0 : len(prefix_tokens)
    ]

    # Sanity check: computed_mrope must not already equal the expected result,
    # or the final assertion would pass vacuously.
    assert not torch.equal(computed_mrope, expected_mrope_masked)
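
    # num_computed_tokens=len(prefix_tokens) marks the text prefix as already
    # processed; positions are recomputed from that point onward.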
    _, actual_mrope, _ = Qwen3VLForConditionalGeneration._recompute_mrope_positions(
        input_ids=input_tokens_pruned,
        multimodal_embeddings=multimodal_embeddings,
        mrope_positions=computed_mrope,
        num_computed_tokens=len(prefix_tokens),
        vision_start_token_id=hf_config.vision_start_token_id,
        image_token_id=hf_config.image_token_id,
        video_token_id=hf_config.video_token_id,
    )

    assert torch.equal(actual_mrope, expected_mrope_masked)