[Bugfix] Fix check_interleaved_audio_video false positive for batched non-interleaved requests (#35487)

Signed-off-by: linyueqian <linyueqian@outlook.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Yueqian Lin
2026-02-27 09:48:25 -05:00
committed by GitHub
parent 6d4f9d3ad5
commit e8249378e4
2 changed files with 51 additions and 4 deletions

View File

@@ -122,8 +122,17 @@ def check_interleaved_audio_video(
"""
Check if video and audio positions are interleaved in the multimodal region.
Returns:
True if video and audio tokens are interleaved, False otherwise.
Returns True only for the use_audio_in_video=True case, where video and
audio tokens alternate within a single contiguous region with no gaps.
A simple range-overlap check produces false positives when multiple
non-interleaved requests are batched together: audio tokens from request N
fall between video tokens from request N and request N+1, making the
global ranges overlap even though each individual request is non-interleaved.
To distinguish true interleaving from this batching artefact we require
that every position in the combined [first_VA, last_VA] range is occupied
by either a video or an audio token (no text/image gaps).
"""
if num_video == 0 or num_audio == 0:
return False
@@ -131,10 +140,22 @@ def check_interleaved_audio_video(
video_pos = is_video.nonzero(as_tuple=True)[0]
audio_pos = is_audio.nonzero(as_tuple=True)[0]
return (
# Quick range-overlap pre-check (necessary but not sufficient).
if not (
video_pos[0].item() < audio_pos[-1].item()
and audio_pos[0].item() < video_pos[-1].item()
)
):
return False
# Density check: for true use_audio_in_video interleaving every position
# in the combined span is a video or audio token. Batched non-interleaved
# requests have text/image tokens between the per-request V and A blocks.
# combined_start/end span every V/A token, so num_video + num_audio is the
# number of V/A tokens inside that span (assuming these counts match the
# is_video/is_audio masks). If it equals the span size there are no gaps.
combined_start = min(video_pos[0].item(), audio_pos[0].item())
combined_end = max(video_pos[-1].item(), audio_pos[-1].item())
total_in_range = combined_end - combined_start + 1
return (num_video + num_audio) == total_in_range
def merge_interleaved_embeddings(