[Bugfix] Fix check_interleaved_audio_video false positive for batched non-interleaved requests (#35487)

Signed-off-by: linyueqian <linyueqian@outlook.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Yueqian Lin
2026-02-27 09:48:25 -05:00
committed by GitHub
parent 6d4f9d3ad5
commit e8249378e4
2 changed files with 51 additions and 4 deletions

View File

@@ -116,6 +116,32 @@ class TestCheckInterleavedAudioVideo:
is_video, is_audio, is_video.sum().item(), is_audio.sum().item()
)
def test_batched_non_interleaved_no_false_positive(self):
    """
    Batched non-interleaved requests must not be reported as interleaved.

    Regression test for https://github.com/vllm-project/vllm/issues/35394.
    Five copies of one mixed-modality request (audio, image and video in
    separate blocks with text between them) are concatenated into a single
    batch. In the batched sequence, audio tokens of a later request sit
    between the video blocks of neighbouring requests, so the global audio
    and video position ranges overlap even though no single request
    interleaves them. check_interleaved_audio_video must return False.
    """
    # One request, built by the shared helper with 5 audio, 4 image and
    # 6 video tokens: [text][audio*5][text][image*4][text][video*6][text]
    request_ids, _ = make_token_seq(5, 4, 6)

    # Three text tokens follow every request to simulate padding between
    # batched requests.
    padding = torch.tensor([TEXT_TOKEN_ID] * 3)
    batched_ids = torch.cat([request_ids, padding] * 5)

    # Per-modality masks over the batched sequence.
    audio_mask = batched_ids == AUDIO_TOKEN_ID
    image_mask = batched_ids == IMAGE_TOKEN_ID
    video_mask = batched_ids == VIDEO_TOKEN_ID
    is_multimodal = audio_mask | image_mask | video_mask

    is_video = is_multimodal & video_mask
    is_audio = is_multimodal & audio_mask
    num_video = is_video.sum().item()
    num_audio = is_audio.sum().item()

    interleaved = check_interleaved_audio_video(
        is_video, is_audio, num_video, num_audio
    )
    assert not interleaved, (
        "Batched non-interleaved requests should not be detected as interleaved"
    )
# ---------------------------------------------------------------------------
# Tests for embed_input_ids via a minimal mock

View File

@@ -122,8 +122,17 @@ def check_interleaved_audio_video(
"""
Check if video and audio positions are interleaved in the multimodal region.
Returns:
True if video and audio tokens are interleaved, False otherwise.
Returns True only for the use_audio_in_video=True case, where video and
audio tokens alternate within a single contiguous region with no gaps.
A simple range-overlap check produces false positives when multiple
non-interleaved requests are batched together: audio tokens from request N
fall between video tokens from request N and request N+1, making the
global ranges overlap even though each individual request is non-interleaved.
To distinguish true interleaving from this batching artefact we require
that every position in the combined [first_VA, last_VA] range is occupied
by either a video or an audio token (no text/image gaps).
"""
if num_video == 0 or num_audio == 0:
return False
@@ -131,10 +140,22 @@ def check_interleaved_audio_video(
video_pos = is_video.nonzero(as_tuple=True)[0]
audio_pos = is_audio.nonzero(as_tuple=True)[0]
return (
# Quick range-overlap pre-check (necessary but not sufficient).
if not (
video_pos[0].item() < audio_pos[-1].item()
and audio_pos[0].item() < video_pos[-1].item()
)
):
return False
# Density check: for true use_audio_in_video interleaving every position
# in the combined span is a video or audio token. Batched non-interleaved
# requests have text/image tokens between the per-request V and A blocks.
# combined_start/end encompass all V/A tokens, so num_video + num_audio
# equals the number of V/A tokens in range; compare directly to span size.
combined_start = min(video_pos[0].item(), audio_pos[0].item())
combined_end = max(video_pos[-1].item(), audio_pos[-1].item())
total_in_range = combined_end - combined_start + 1
return (num_video + num_audio) == total_in_range
def merge_interleaved_embeddings(