[Bugfix] Fix check_interleaved_audio_video false positive for batched non-interleaved requests (#35487)
Signed-off-by: linyueqian <linyueqian@outlook.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -122,8 +122,17 @@ def check_interleaved_audio_video(
|
||||
"""
|
||||
Check if video and audio positions are interleaved in the multimodal region.
|
||||
|
||||
Returns:
|
||||
True if video and audio tokens are interleaved, False otherwise.
|
||||
Returns True only for the use_audio_in_video=True case, where video and
|
||||
audio tokens alternate within a single contiguous region with no gaps.
|
||||
|
||||
A simple range-overlap check produces false positives when multiple
|
||||
non-interleaved requests are batched together: audio tokens from request N
|
||||
fall between video tokens from request N and request N+1, making the
|
||||
global ranges overlap even though each individual request is non-interleaved.
|
||||
|
||||
To distinguish true interleaving from this batching artefact, we require
|
||||
that every position in the combined [first_VA, last_VA] range is occupied
|
||||
by either a video or an audio token (no text/image gaps).
|
||||
"""
|
||||
if num_video == 0 or num_audio == 0:
|
||||
return False
|
||||
@@ -131,10 +140,22 @@ def check_interleaved_audio_video(
|
||||
video_pos = is_video.nonzero(as_tuple=True)[0]
|
||||
audio_pos = is_audio.nonzero(as_tuple=True)[0]
|
||||
|
||||
return (
|
||||
# Quick range-overlap pre-check (necessary but not sufficient).
|
||||
if not (
|
||||
video_pos[0].item() < audio_pos[-1].item()
|
||||
and audio_pos[0].item() < video_pos[-1].item()
|
||||
)
|
||||
):
|
||||
return False
|
||||
|
||||
# Density check: for true use_audio_in_video interleaving, every position
|
||||
# in the combined span is a video or audio token. Batched non-interleaved
|
||||
# requests have text/image tokens between the per-request V and A blocks.
|
||||
# combined_start/end encompass all V/A tokens, so num_video + num_audio
|
||||
# equals the number of V/A tokens in range; compare directly to span size.
|
||||
combined_start = min(video_pos[0].item(), audio_pos[0].item())
|
||||
combined_end = max(video_pos[-1].item(), audio_pos[-1].item())
|
||||
total_in_range = combined_end - combined_start + 1
|
||||
return (num_video + num_audio) == total_in_range
|
||||
|
||||
|
||||
def merge_interleaved_embeddings(
|
||||
|
||||
Reference in New Issue
Block a user