[Scheduler][ASR] Fix CrossAttn blocks per-request for variable-length encoder inputs (#31058)

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Ekagra Ranjan
2026-02-16 06:08:44 -05:00
committed by GitHub
parent 1e828573b4
commit cd81cdb399
2 changed files with 305 additions and 13 deletions


@@ -3676,6 +3676,300 @@ def test_abort_request_finished_recving():
assert not scheduler.finished_recving_kv_req_ids


# ==============================================================================
# Variable-length encoder cross-attention block allocation tests
# ==============================================================================


def _create_encoder_decoder_scheduler(
block_size: int = 16,
num_blocks: int = 10000,
max_num_batched_tokens: int = 8192,
max_num_seqs: int = 16,
) -> Scheduler:
"""Create a scheduler configured for encoder-decoder cross-attention
block allocation testing.

Constructs a scheduler with both FullAttentionSpec (self-attention) and
CrossAttentionSpec (cross-attention) KV cache groups, then patches it
to behave as an encoder-decoder model.
"""
from vllm.v1.core.encoder_cache_manager import EncoderDecoderCacheManager
from vllm.v1.kv_cache_interface import CrossAttentionSpec
model_config = ModelConfig(
model="facebook/opt-125m",
trust_remote_code=True,
dtype="float16",
seed=42,
)
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens,
# is_encoder_decoder disables chunked prefill and prefix caching
is_encoder_decoder=True,
)
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=False,
)
cache_config.num_gpu_blocks = num_blocks
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
)
# KV cache config with both self-attention and cross-attention groups,
# mirroring an encoder-decoder model like Whisper.
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["self_attn_layer"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
),
KVCacheGroupSpec(
["cross_attn_layer"],
CrossAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
),
],
)
# Construct the scheduler. Since opt-125m is not truly encoder-decoder,
# the __init__ won't set up encoder-decoder internals. We patch them
# after construction.
scheduler = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
block_size=block_size,
structured_output_manager=StructuredOutputManager(vllm_config),
)
# Patch to enable encoder-decoder behavior in the scheduling loop.
scheduler.is_encoder_decoder = True
scheduler.max_num_encoder_input_tokens = max_num_batched_tokens
scheduler.encoder_cache_manager = EncoderDecoderCacheManager(
cache_size=max_num_batched_tokens
)
return scheduler
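

# A minimal usage sketch of the factory above (illustrative only; the
# attribute path mirrors what _get_num_cross_attn_blocks below relies on):
#
#   scheduler = _create_encoder_decoder_scheduler(block_size=16)
#   managers = scheduler.kv_cache_manager.coordinator.single_type_managers
#
# One single-type manager exists per KV cache group (self-attention and
# cross-attention); the helper below finds the latter via isinstance().
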
def _get_num_cross_attn_blocks(scheduler: Scheduler, request_id: str) -> int:
"""Get the number of cross-attention blocks allocated for a request."""
from vllm.v1.core.single_type_kv_cache_manager import CrossAttentionManager
coordinator = scheduler.kv_cache_manager.coordinator
for manager in coordinator.single_type_managers:
if isinstance(manager, CrossAttentionManager):
blocks = manager.req_to_blocks.get(request_id, [])
return len(blocks)
raise AssertionError("No CrossAttentionManager found in coordinator")
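

# Reference math used throughout these tests (a sketch, not scheduler
# internals): a request with enc_len encoder tokens needs
# ceil(enc_len / block_size) cross-attention blocks, i.e. in pure-integer form:
#
#   def _expected_cross_attn_blocks(enc_len: int, block_size: int) -> int:
#       # Hypothetical helper for illustration; the tests below inline the
#       # same computation via math.ceil.
#       return (enc_len + block_size - 1) // block_size

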
def test_variable_length_cross_attn_block_allocation():
"""Test that cross-attention blocks are allocated per-request based on
actual encoder input length, not a fixed maximum.

With fixed max-encoder-length allocation, every request would receive
`ceil(max_encoder_tokens / block_size)` blocks; with dynamic allocation,
each request receives exactly `ceil(actual_encoder_tokens / block_size)`
blocks.
"""
block_size = 16
scheduler = _create_encoder_decoder_scheduler(block_size=block_size)
# Create requests with distinctly different encoder input lengths,
# simulating variable-length audio inputs to a model like Whisper.
encoder_lengths = [500, 1000, 200]
num_prompt_tokens = 100 # Decoder prompt tokens
requests = []
for i, enc_len in enumerate(encoder_lengths):
req = create_requests(
num_requests=1,
num_tokens=num_prompt_tokens,
mm_hashes_list=[[f"enc_hash_{i}"]],
mm_positions=[[PlaceholderRange(offset=0, length=enc_len)]],
req_ids=[f"req_{i}"],
)[0]
requests.append(req)
# Add and schedule all requests.
for req in requests:
scheduler.add_request(req)
output = scheduler.schedule()
# All requests should be scheduled.
assert len(output.scheduled_new_reqs) == len(requests)
# Verify cross-attention blocks per request match the actual encoder length.
from math import ceil
for req, enc_len in zip(requests, encoder_lengths):
expected_blocks = ceil(enc_len / block_size)
actual_blocks = _get_num_cross_attn_blocks(scheduler, req.request_id)
assert actual_blocks == expected_blocks, (
f"Request {req.request_id} with {enc_len} encoder tokens: "
f"expected {expected_blocks} cross-attn blocks, "
f"got {actual_blocks}"
)
# Verify that different encoder lengths produce different block counts,
# confirming variable-length (not fixed-max) allocation.
block_counts = [
_get_num_cross_attn_blocks(scheduler, req.request_id) for req in requests
]
assert len(set(block_counts)) > 1, (
"All requests have the same number of cross-attn blocks, "
"suggesting static max-based allocation instead of per-request"
)


def test_cross_attn_blocks_not_over_allocated():
"""Test that cross-attention blocks are not over-allocated compared to
what each request actually needs."""
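# Concretely, with the values set below: 200 encoder tokens need
# ceil(200 / 16) = 13 blocks, while max-based allocation would reserve
# ceil(1500 / 16) = 94 blocks for every request.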
from math import ceil
block_size = 16
max_encoder_tokens = 1500 # e.g., Whisper's max mel-spectrogram length
scheduler = _create_encoder_decoder_scheduler(block_size=block_size)
# Request with a small encoder input (much less than the max).
small_enc_len = 200
request = create_requests(
num_requests=1,
num_tokens=100,
mm_hashes_list=[["enc_small"]],
mm_positions=[[PlaceholderRange(offset=0, length=small_enc_len)]],
req_ids=["req_small"],
)[0]
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
actual_blocks = _get_num_cross_attn_blocks(scheduler, request.request_id)
expected_blocks = ceil(small_enc_len / block_size)
max_blocks = ceil(max_encoder_tokens / block_size)
# Blocks should match the actual encoder length.
assert actual_blocks == expected_blocks, (
f"Expected {expected_blocks} blocks for {small_enc_len} encoder tokens, "
f"got {actual_blocks}"
)
# Blocks should be strictly less than what max-based allocation would give.
assert actual_blocks < max_blocks, (
f"Cross-attn blocks ({actual_blocks}) should be less than max "
f"({max_blocks}), indicating no over-allocation"
)


def test_cross_attn_blocks_not_under_allocated():
"""Test that cross-attention blocks are sufficient for each request's
actual encoder input length. Every encoder token must have a slot.

Tests various edge cases including exact block boundaries, off-by-one,
and the minimum/maximum encoder input sizes.
"""
from math import ceil
block_size = 16
# Test various encoder lengths including edge cases around block boundaries.
test_cases = [
1, # Minimum: single encoder token
block_size - 1, # Just under one full block
block_size, # Exactly one full block
block_size + 1, # Just over one block (needs 2 blocks)
block_size * 10, # Exact multiple of block size
block_size * 10 + 1, # One over exact multiple
1500, # Whisper's typical max
]
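# Expected block counts for the cases above at block_size=16 (ceil division):
#   1 -> 1, 15 -> 1, 16 -> 1, 17 -> 2, 160 -> 10, 161 -> 11, 1500 -> 94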
for enc_len in test_cases:
scheduler = _create_encoder_decoder_scheduler(block_size=block_size)
request = create_requests(
num_requests=1,
num_tokens=100,
mm_hashes_list=[[f"enc_{enc_len}"]],
mm_positions=[[PlaceholderRange(offset=0, length=enc_len)]],
req_ids=[f"req_{enc_len}"],
)[0]
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
actual_blocks = _get_num_cross_attn_blocks(scheduler, request.request_id)
expected_blocks = ceil(enc_len / block_size)
# Number of blocks must be exactly ceil(enc_len / block_size).
assert actual_blocks == expected_blocks, (
f"Encoder length {enc_len}: expected {expected_blocks} blocks, "
f"got {actual_blocks}"
)
# Total available slots must be >= encoder tokens (no under-allocation).
total_slots = actual_blocks * block_size
assert total_slots >= enc_len, (
f"Encoder length {enc_len}: total slots {total_slots} < "
f"needed {enc_len} (under-allocation)"
)


def test_cross_attn_zero_blocks_without_encoder_inputs():
"""Test that requests without encoder inputs get zero cross-attention
blocks, even when the scheduler is configured for encoder-decoder."""
block_size = 16
scheduler = _create_encoder_decoder_scheduler(block_size=block_size)
# Create a text-only request (no mm_features).
request = create_requests(
num_requests=1,
num_tokens=100,
req_ids=["req_text_only"],
)[0]
# Text-only request has no encoder inputs.
assert not request.has_encoder_inputs
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
# No cross-attention blocks should be allocated.
actual_blocks = _get_num_cross_attn_blocks(scheduler, request.request_id)
assert actual_blocks == 0, (
f"Text-only request should have 0 cross-attn blocks, got {actual_blocks}"
)


def test_eagle3_mm_encoder_cache_with_shift():
"""Test EAGLE3 encoder scheduling accounts for shift_computed_tokens.