[Bugfix] Fix EVS implementation for Qwen3 VL (#33607)
Signed-off-by: 2ez4bz <133824995+2ez4bz@users.noreply.github.com>
This commit is contained in:
237
tests/model_executor/test_qwen3_vl_mrope.py
Normal file
237
tests/model_executor/test_qwen3_vl_mrope.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalFieldElem,
|
||||
MultiModalKwargsItem,
|
||||
PlaceholderRange,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
    """Pin torch's default device to CPU for the tests in this module.

    ``_get_mrope_input_positions`` returns CPU tensors (via
    ``torch.from_numpy``), so every tensor the tests create themselves has to
    live on the CPU as well. The device that was active beforehand is put back
    once the fixture is torn down.
    """
    previous_device = torch.get_default_device()
    torch.set_default_device("cpu")
    yield
    # Teardown: restore whatever default device was configured before.
    torch.set_default_device(previous_device)
|
||||
|
||||
|
||||
# Placeholder ids for the special multimodal tokens used to build the
# synthetic prompts below. Plain text tokens are sampled from range(1, 100),
# so none of these values can collide with a text token.
IMAGE_TOKEN_ID = 999
VIDEO_TOKEN_ID = 888
# Markers that open / close the vision tokens of a single frame.
VISION_START_TOKEN_ID = 777
VISION_END_TOKEN_ID = 778
|
||||
|
||||
|
||||
@dataclasses.dataclass
class DummyVisionConfig:
    """Stand-in for the ``vision_config`` attribute of the model config."""

    # The test divides the grid height / width by this value before building
    # the per-frame video tokens.
    spatial_merge_size: int = 1
|
||||
|
||||
|
||||
@dataclasses.dataclass
class DummyConfig:
    """Stand-in model config exposing only the fields this test module sets.

    Every token id defaults to the matching module-level constant, and each
    instance gets its own ``DummyVisionConfig`` so a test can mutate it
    without affecting other instances.
    """

    image_token_id: int = IMAGE_TOKEN_ID
    video_token_id: int = VIDEO_TOKEN_ID
    vision_start_token_id: int = VISION_START_TOKEN_ID
    vision_end_token_id: int = VISION_END_TOKEN_ID
    # default_factory (not a shared default instance) keeps configs independent.
    vision_config: DummyVisionConfig = dataclasses.field(
        default_factory=DummyVisionConfig
    )
|
||||
|
||||
|
||||
def make_video_embedding(
    t, h, w, interleave_text_tokens: tuple[int, int], video_pruning_rate: float = 0.0
):
    """
    Build the token sequence of a synthetic video, with and without pruning.

    Each of the ``t`` frames contributes a random run of text tokens followed
    by ``VISION_START``, ``h * w`` video tokens and ``VISION_END``. Video
    tokens are then dropped independently with probability
    ``video_pruning_rate``; text and vision start / end tokens are always kept.

    Args:
        t: Number of frames.
        h: Number of rows.
        w: Number of columns.
        interleave_text_tokens: Tuple of minimum and maximum number of text
            tokens placed in front of each frame (both bounds inclusive).
        video_pruning_rate: Probability of pruning each individual video token.

    Returns:
        Tuple of (unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask)
        where ``retention_mask`` is a boolean mask over the unpruned sequence.
    """
    text_vocab = list(range(1, 100))
    min_text_tokens = interleave_text_tokens[0]
    max_text_tokens = interleave_text_tokens[1]
    # The vision part is identical for every frame, so build it only once.
    frame_tokens = (
        [VISION_START_TOKEN_ID] + [VIDEO_TOKEN_ID] * h * w + [VISION_END_TOKEN_ID]
    )

    token_ids = []
    for _frame in range(t):
        num_text_tokens = random.randint(min_text_tokens, max_text_tokens)
        token_ids += random.choices(text_vocab, k=num_text_tokens)
        token_ids += frame_tokens

    unpruned_tokens_sequence = torch.tensor(token_ids, dtype=torch.long)
    is_video_token = unpruned_tokens_sequence == VIDEO_TOKEN_ID

    # Non-video positions get probability 0, so they can never be sampled.
    prune_probability = is_video_token.float() * video_pruning_rate
    pruning_mask = torch.bernoulli(prune_probability).bool()  # type: ignore[attr-defined]
    # Sanity check that we don't prune what should not be pruned.
    assert not pruning_mask[~is_video_token].any()

    retention_mask = ~pruning_mask
    pruned_tokens_sequence = unpruned_tokens_sequence[retention_mask]
    return unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spatial_merge_size", [1, 2])
@pytest.mark.parametrize("grid_thw", [(3, 8, 7), (128, 10, 12)])
@pytest.mark.parametrize("num_prefix_tokens", [1, 11])
@pytest.mark.parametrize("num_suffix_tokens", [0, 7])
@pytest.mark.parametrize("video_pruning_rate", [0, 0.25, 0.75])
@pytest.mark.parametrize("interleave_text_tokens", [(0, 0), (1, 4)])
def test_match_qwen3vl_mrope_evs_on(
    spatial_merge_size: int,
    num_prefix_tokens: int,
    grid_thw: tuple[int, int, int],
    num_suffix_tokens: int,
    video_pruning_rate: float,
    interleave_text_tokens: tuple[int, int],
):
    """Recomputed mrope positions of a pruned prompt must match the reference.

    The reference is ``_get_mrope_input_positions`` evaluated on the full
    (unpruned) prompt and then restricted to the tokens that survive pruning.
    ``_recompute_mrope_positions`` is given the pruned prompt, the positions of
    the text prefix only, and per-token video positions attached to the
    embeddings; its output has to equal the reference exactly.
    """
    hf_config = DummyConfig()
    hf_config.vision_config.spatial_merge_size = spatial_merge_size

    t, h, w = grid_thw
    population = list(range(1, 100))
    prefix_tokens = random.choices(population, k=num_prefix_tokens)
    suffix_tokens = random.choices(population, k=num_suffix_tokens)

    video_tokens, video_tokens_pruned, retention_mask = make_video_embedding(
        t,
        h // spatial_merge_size,
        w // spatial_merge_size,
        interleave_text_tokens=interleave_text_tokens,
        video_pruning_rate=video_pruning_rate,
    )
    assert len(video_tokens) == len(retention_mask)

    input_tokens = prefix_tokens + video_tokens.tolist() + suffix_tokens
    input_tokens_pruned = prefix_tokens + video_tokens_pruned.tolist() + suffix_tokens

    # Prefix and suffix text is never pruned.
    whole_sequence_retention_mask = torch.cat(
        [
            torch.ones(len(prefix_tokens), dtype=torch.bool),
            retention_mask,
            torch.ones(len(suffix_tokens), dtype=torch.bool),
        ],
        dim=0,
    )

    def _video_feature(num_tokens: int) -> MultiModalFeatureSpec:
        # Single video feature carrying only `video_grid_thw`, with a
        # placeholder range that starts at offset 0 and spans `num_tokens`.
        return MultiModalFeatureSpec(
            data=MultiModalKwargsItem(
                {
                    "video_grid_thw": MultiModalFieldElem(
                        data=torch.tensor(grid_thw),
                        field=None,  # HACK.
                    ),
                }
            ),
            modality="video",
            identifier="DUMMY",
            mm_position=PlaceholderRange(offset=0, length=num_tokens),
        )

    # Build the GT mrope for unpruned input.
    expected_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions(
        input_tokens=input_tokens,
        mm_features=[_video_feature(len(input_tokens))],
        config=hf_config,
    )

    # Compute mrope for a video-only media (unpruned).
    video_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions(
        input_tokens=video_tokens.tolist(),
        mm_features=[_video_feature(video_tokens.numel())],
        config=hf_config,
    )
    video_mrope = video_mrope.permute(1, 0)  # [N, 3]
    hidden_size = 16

    is_video_embed = torch.isin(
        video_tokens_pruned, torch.tensor([VIDEO_TOKEN_ID], dtype=torch.long)
    )
    is_vision_start = video_tokens_pruned == VISION_START_TOKEN_ID

    # Per-token extra channels appended to the embeddings:
    # [t, h, w, is_vision_start, is_video_embed].
    expanded_positions = torch.full(
        (len(video_tokens_pruned), 5),
        fill_value=-100,
        device=video_mrope.device,
        dtype=torch.long,
    )
    # Video and non-video rows both take the positions of the retained tokens,
    # so a single assignment covers every row.
    expanded_positions[:, :3] = video_mrope[retention_mask]
    expanded_positions[..., 3] = is_vision_start
    expanded_positions[..., 4] = is_video_embed

    # Check that all positions were filled, since we initialized them as negative.
    assert (expanded_positions >= 0).all()

    # The hidden-state values are irrelevant here; zeros (rather than
    # uninitialized memory) keep the input deterministic.
    video_embeddings = torch.zeros(
        (len(video_tokens_pruned), hidden_size), device=video_mrope.device
    )
    video_embeddings = torch.cat(
        [
            video_embeddings,
            expanded_positions.float(),
        ],
        dim=1,
    )
    multimodal_embeddings = [video_embeddings]

    expected_mrope_masked = expected_mrope[:, whole_sequence_retention_mask]

    # Only the text prefix has known positions before recomputation. The rest
    # is filled with a -1 sentinel instead of torch.empty's uninitialized
    # memory, so the check below cannot depend on leftover buffer contents.
    computed_mrope = torch.full(
        (3, len(input_tokens_pruned)), fill_value=-1, dtype=torch.long
    )
    computed_mrope[:, 0 : len(prefix_tokens)] = expected_mrope[
        :, 0 : len(prefix_tokens)
    ]

    # Paranoia check that computed_mrope is wrong.
    assert not torch.equal(computed_mrope, expected_mrope_masked)

    _, actual_mrope, _ = Qwen3VLForConditionalGeneration._recompute_mrope_positions(
        input_ids=input_tokens_pruned,
        multimodal_embeddings=multimodal_embeddings,
        mrope_positions=computed_mrope,
        num_computed_tokens=len(prefix_tokens),
        vision_start_token_id=hf_config.vision_start_token_id,
        image_token_id=hf_config.image_token_id,
        video_token_id=hf_config.video_token_id,
    )

    assert torch.equal(actual_mrope, expected_mrope_masked)
|
||||
@@ -195,6 +195,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
|
||||
- second_per_grid_ts: The video time interval (in seconds) for each
|
||||
grid along the temporal dimension in the 3D position IDs. Returned
|
||||
when `videos` is not `None`.
|
||||
- timestamps: List of timestamp values (in seconds) for each frame
|
||||
after merging. Length equals the temporal dimension after merging.
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values_videos"]
|
||||
@@ -214,6 +216,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
|
||||
TensorShape("nv"),
|
||||
]
|
||||
|
||||
timestamps: list[list[float]] | None = None
|
||||
|
||||
|
||||
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
@@ -232,6 +236,8 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
|
||||
- second_per_grid_ts: The video time interval (in seconds) for each
|
||||
grid along the temporal dimension in the 3D position IDs. Returned
|
||||
when `videos` is not `None`.
|
||||
- timestamps: List of timestamp values (in seconds) for each frame
|
||||
after merging. Length equals the temporal dimension after merging.
|
||||
"""
|
||||
|
||||
type: Literal["video_embeds"]
|
||||
@@ -250,6 +256,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
|
||||
torch.Tensor | None,
|
||||
TensorShape("nv"),
|
||||
] = None
|
||||
timestamps: list[list[float]] | None = None
|
||||
|
||||
|
||||
Qwen2_5_VLVideoInputs: TypeAlias = (
|
||||
|
||||
@@ -755,6 +755,7 @@ def _create_qwen2vl_field_factory(
|
||||
"video", video_embed_grid_sizes
|
||||
),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
|
||||
timestamps=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
|
||||
)
|
||||
|
||||
return _qwen2vl_field_config
|
||||
|
||||
@@ -628,6 +628,9 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder,
|
||||
)
|
||||
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
|
||||
# Qwen3.5 does not support multimodal pruning (EVS).
|
||||
supports_multimodal_pruning = False
|
||||
|
||||
packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
|
||||
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
|
||||
"in_proj_ba": ["in_proj_b", "in_proj_a"],
|
||||
@@ -643,10 +646,8 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
self.is_multimodal_pruning_enabled = (
|
||||
multimodal_config.is_multimodal_pruning_enabled()
|
||||
)
|
||||
# Qwen3.5 does not support multimodal pruning (EVS).
|
||||
self.is_multimodal_pruning_enabled = False
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
@@ -693,6 +694,12 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def recompute_mrope_positions(self, *args, **kwargs):
    """Always raise: Qwen3.5 does not support multimodal pruning (EVS).

    Raises:
        NotImplementedError: On every call, regardless of the arguments.
    """
    message = (
        "Qwen3.5 does not support multimodal pruning (EVS). "
        "recompute_mrope_positions should never be called."
    )
    raise NotImplementedError(message)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -851,10 +858,8 @@ class Qwen3_5MoeForConditionalGeneration(
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
self.is_multimodal_pruning_enabled = (
|
||||
multimodal_config.is_multimodal_pruning_enabled()
|
||||
)
|
||||
# Qwen3.5 does not support multimodal pruning (EVS).
|
||||
self.is_multimodal_pruning_enabled = False
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers.registry import cached_tokenizer_from_config
|
||||
|
||||
from .interfaces import MixtureOfExperts
|
||||
from .qwen3_moe import (
|
||||
@@ -415,6 +416,7 @@ class Qwen3VLMoeForConditionalGeneration(
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
self.config = config
|
||||
self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
|
||||
@@ -170,9 +170,9 @@ def recompute_mrope_positions(
|
||||
multimodal_embeddings may contain zero, some or even some part of all
|
||||
multimodal_embeddings for a given prompt.
|
||||
|
||||
Each multimodal_positions has 4 extra channels
|
||||
(First 3 channels corresponds to original 3 mrope positions, last channel
|
||||
is the maximum width of the media repeated). Provided multimodal_positions
|
||||
Each multimodal_positions has 4 or 5 extra channels
|
||||
(first 3 channels correspond to the original 3 mrope positions;
|
||||
remaining channels vary by model — see below). Provided multimodal_positions
|
||||
do not reflect location of media position in sequence - they are computed
|
||||
like the media is in the 0-th position in the sequence.
|
||||
|
||||
@@ -186,6 +186,16 @@ def recompute_mrope_positions(
|
||||
Args:
|
||||
input_ids: (N,) All input tokens of the prompt (entire sequence).
|
||||
multimodal_positions: List of mrope positions for each media.
|
||||
If a given element is of shape (4, N), it is assumed to only describe
|
||||
positions for video / image embeddings. This is the case of e.g. Qwen2.5 VL,
|
||||
where each multimodal input is a contiguous chunk of embeddings.
|
||||
The expected channels are [t, h, w, max_width].
|
||||
If it is of shape (5, N), it is assumed to possibly describe positions for
|
||||
both video / image embeddings, as well as text embeddings. This is the case
|
||||
of e.g. Qwen3 VL, where each video input is comprised of individual
|
||||
frames' embeddings, interleaved with embeddings for timestamp tokens,
|
||||
and vision start / end tokens. The expected channels are
|
||||
[t, h, w, is_vision_start, is_vision].
|
||||
mrope_positions: Existing mrope positions (4, N) for entire sequence.
|
||||
num_computed_tokens: Number of tokens computed so far.
|
||||
vision_start_token_id: Token indicating start of vision media.
|
||||
@@ -233,6 +243,21 @@ def recompute_mrope_positions(
|
||||
# - Current prefill chunk has no vision start indexes at all
|
||||
# - Vision start token appeared in previous prefill round
|
||||
# - Regular case
|
||||
has_video_tokens = False
|
||||
num_timestamp_tokens = 0
|
||||
if mm_pos.shape[0] == 5 and mm_pos.shape[1] > 0:
|
||||
# mm_pos[4, :] indicates which positions are for video embeddings.
|
||||
# If there are no video embeddings, skip timestamp adjustment.
|
||||
has_video_tokens = torch.any(mm_pos[4, :]).item()
|
||||
if has_video_tokens:
|
||||
# Channel 3 flags VISION_START tokens. Timestamp tokens
|
||||
# precede the first VISION_START, so its index gives us the
|
||||
# exact timestamp count. This is robust even when early
|
||||
# frames have all their video tokens pruned (which would
|
||||
# push argmax(channel 4) far into a later frame).
|
||||
first_vs = (mm_pos[3, :] == 1).nonzero(as_tuple=True)[0]
|
||||
num_timestamp_tokens = first_vs[0].item() if len(first_vs) > 0 else 0
|
||||
|
||||
seen_vision_start_indices = vision_start_indices[
|
||||
vision_start_indices < num_computed_tokens
|
||||
]
|
||||
@@ -249,6 +274,18 @@ def recompute_mrope_positions(
|
||||
in_the_middle_of_media = (
|
||||
seen_mm_tokens > seem_mm_tokens_before_last_vision_start
|
||||
)
|
||||
# For Qwen3 VL, we can be inside a media segment even before any
|
||||
# video tokens appear (timestamp tokens are text). If we've passed
|
||||
# the last vision_start token but haven't reached the first video
|
||||
# embedding, treat this as "in the middle of media".
|
||||
if (
|
||||
not in_the_middle_of_media
|
||||
and has_video_tokens
|
||||
and num_computed_tokens > last_vision_start_token
|
||||
and num_computed_tokens
|
||||
<= last_vision_start_token + num_timestamp_tokens + 1
|
||||
):
|
||||
in_the_middle_of_media = True
|
||||
|
||||
if in_the_middle_of_media:
|
||||
mm_embeddings_seen = (
|
||||
@@ -274,14 +311,39 @@ def recompute_mrope_positions(
|
||||
mm_embeddings_seen = 0
|
||||
global_mm_start = next_vision_start_token
|
||||
|
||||
# Offset right after vision_start_token
|
||||
base = positions[-1, global_mm_start] + 1
|
||||
local_start = global_mm_start + 1 + mm_embeddings_seen
|
||||
# For Qwen3 VL, mm_pos includes timestamp tokens before vision_start
|
||||
# when starting a new media. Adjust global_mm_start to point to where
|
||||
# the sequence actually begins (before timestamp tokens).
|
||||
adjusted_for_timestamps = False
|
||||
if mm_pos.shape[0] == 5 and mm_embeddings_seen == 0 and has_video_tokens:
|
||||
# NOTE: -1 is because there is a vision start token right after
|
||||
# timestamp tokens before any video embeddings appear.
|
||||
|
||||
# Adjust global_mm_start to point to the first timestamp token
|
||||
# instead of the vision_start token.
|
||||
global_mm_start -= num_timestamp_tokens
|
||||
adjusted_for_timestamps = True
|
||||
|
||||
# Offset calculation depends on whether we adjusted for timestamp tokens
|
||||
if adjusted_for_timestamps:
|
||||
# Start from position before the first timestamp token
|
||||
base = positions[-1, global_mm_start - 1] + 1
|
||||
local_start = global_mm_start + mm_embeddings_seen
|
||||
else:
|
||||
# Original logic: start after vision_start_token
|
||||
base = positions[-1, global_mm_start] + 1
|
||||
local_start = global_mm_start + 1 + mm_embeddings_seen
|
||||
|
||||
local_end = local_start + mm_pos.shape[1]
|
||||
positions[:, local_start:local_end] = mm_pos[0:3] + base
|
||||
|
||||
# mm_pos[3, 0] is the max width of the media
|
||||
offset = mm_pos[3, 0] + base
|
||||
# For Qwen3 VL (5-channel), use the maximum position reached across
|
||||
# all tokens (both video and text) in all dimensions (t, h, w).
|
||||
# For Qwen2.5 VL (4-channel), mm_pos[3, 0] is the max width.
|
||||
if mm_pos.shape[0] == 5:
|
||||
offset = mm_pos[0:3, :].max() + base + 1
|
||||
else:
|
||||
offset = mm_pos[3, 0] + base
|
||||
|
||||
text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user