[V1][BugFix] Fix EAGLE3 encoder cache miss with disable_chunked_mm_input (#34220)
Signed-off-by: KrxGu <krishom70@gmail.com>
This commit is contained in:
@@ -3675,3 +3675,72 @@ def test_abort_request_finished_recving():
|
||||
# verify request is deleted
|
||||
assert request.request_id not in scheduler.requests
|
||||
assert not scheduler.finished_recving_kv_req_ids
|
||||
|
||||
|
||||
def test_eagle3_mm_encoder_cache_with_shift():
    """Encoder inputs must be scheduled when the shifted range reaches MM.

    Regression test for issue #32469. With EAGLE3 enabled and
    disable_chunked_mm_input=True, a scheduling step whose token range
    overlaps a multimodal placeholder -- once shift_computed_tokens is
    added to the end of the range -- must also schedule the encoder input
    for that placeholder.

    Before the fix the scheduler could leave the encoder input unscheduled
    at the boundary, which surfaced as "Encoder cache miss" errors.
    """
    # Prompt layout: 100 text tokens, one 576-token MM placeholder, then
    # 100 more text tokens.
    placeholder_offset = 100
    placeholder_length = 576
    prompt_length = placeholder_offset + placeholder_length + 100

    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        max_num_batched_tokens=1024,
        disable_chunked_mm_input=True,
        max_model_len=2048,
        # Intended to enable EAGLE, which shifts computed tokens by 1.
        # NOTE(review): confirm create_scheduler maps this argument to an
        # EAGLE speculative config rather than another method.
        num_speculative_tokens=4,
    )

    (request,) = create_requests(
        num_requests=1,
        num_tokens=prompt_length,
        mm_positions=[
            [
                PlaceholderRange(
                    offset=placeholder_offset,
                    length=placeholder_length,
                )
            ],
        ],
    )

    # Explicitly start from a request with nothing computed yet.
    request.num_computed_tokens = 0

    scheduler.add_request(request)
    step = scheduler.schedule()
    assert step is not None

    # Shift this test assumes the scheduler applies for EAGLE.
    expected_shift = 1
    req_id = request.request_id

    assert req_id in step.num_scheduled_tokens
    num_scheduled = step.num_scheduled_tokens[req_id]

    # Work out where the scheduled range ends, with and without the shift,
    # relative to the first MM placeholder of the request.
    start_pos = request.mm_features[0].mm_position.offset
    tokens_end = request.num_computed_tokens + num_scheduled
    scheduled_end_with_shift = tokens_end + expected_shift

    # Sanity-check the setup: the shifted range has to enter the MM span,
    # otherwise the assertion below would not be exercising anything.
    assert scheduled_end_with_shift > start_pos, (
        f"Test setup error: expected to schedule into MM range. "
        f"scheduled_end_with_shift={scheduled_end_with_shift}, "
        f"start_pos={start_pos}"
    )

    # Core check: the shifted range overlaps the MM span, so the encoder
    # input for this request has to be part of the same scheduler output.
    assert req_id in step.scheduled_encoder_inputs, (
        f"Encoder input missing: scheduled {num_scheduled} tokens "
        f"(computed={request.num_computed_tokens}, end={tokens_end}, "
        f"shifted_end={scheduled_end_with_shift}) overlapping MM at "
        f"{start_pos}. The fix must schedule encoder inputs."
    )
|
||||
|
||||
@@ -1155,7 +1155,12 @@ class Scheduler(SchedulerInterface):
|
||||
and (num_computed_tokens + num_new_tokens)
|
||||
< (start_pos + num_encoder_tokens)
|
||||
):
|
||||
num_new_tokens = start_pos - num_computed_tokens
|
||||
# Account for EAGLE shift when rolling back to avoid
|
||||
# encoder cache miss. This ensures the scheduled range
|
||||
# stops before start_pos even with the shift.
|
||||
num_new_tokens = max(
|
||||
0, start_pos - (num_computed_tokens + shift_computed_tokens)
|
||||
)
|
||||
break
|
||||
if not self.encoder_cache_manager.can_allocate(
|
||||
request, i, encoder_compute_budget, num_embeds_to_schedule
|
||||
|
||||
Reference in New Issue
Block a user