[Scheduler][ASR] Fix CrossAttn blocks per-request for Variable length encoder inputs (#31058)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
@@ -3676,6 +3676,300 @@ def test_abort_request_finished_recving():
    assert not scheduler.finished_recving_kv_req_ids


# ==============================================================================
# Variable-length encoder cross-attention block allocation tests
# ==============================================================================


def _create_encoder_decoder_scheduler(
    block_size: int = 16,
    num_blocks: int = 10000,
    max_num_batched_tokens: int = 8192,
    max_num_seqs: int = 16,
) -> Scheduler:
    """Create a scheduler configured for encoder-decoder cross-attention
    block allocation testing.

    Constructs a scheduler with both FullAttentionSpec (self-attention) and
    CrossAttentionSpec (cross-attention) KV cache groups, then patches it
    to behave as an encoder-decoder model.
    """
    from vllm.v1.core.encoder_cache_manager import EncoderDecoderCacheManager
    from vllm.v1.kv_cache_interface import CrossAttentionSpec

    model_config = ModelConfig(
        model="facebook/opt-125m",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
        # is_encoder_decoder disables chunked prefill and prefix caching
        is_encoder_decoder=True,
    )
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=False,
    )
    cache_config.num_gpu_blocks = num_blocks

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
    )

    # KV cache config with both self-attention and cross-attention groups,
    # mirroring an encoder-decoder model like Whisper.
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["self_attn_layer"],
                FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                ),
            ),
            KVCacheGroupSpec(
                ["cross_attn_layer"],
                CrossAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                ),
            ),
        ],
    )

    # Construct the scheduler. Since opt-125m is not truly encoder-decoder,
    # the __init__ won't set up encoder-decoder internals. We patch them
    # after construction.
    scheduler = Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        block_size=block_size,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )

    # Patch to enable encoder-decoder behavior in the scheduling loop.
    scheduler.is_encoder_decoder = True
    scheduler.max_num_encoder_input_tokens = max_num_batched_tokens
    scheduler.encoder_cache_manager = EncoderDecoderCacheManager(
        cache_size=max_num_batched_tokens
    )

    return scheduler


def _get_num_cross_attn_blocks(scheduler: Scheduler, request_id: str) -> int:
    """Get the number of cross-attention blocks allocated for a request."""
    from vllm.v1.core.single_type_kv_cache_manager import CrossAttentionManager

    coordinator = scheduler.kv_cache_manager.coordinator
    for manager in coordinator.single_type_managers:
        if isinstance(manager, CrossAttentionManager):
            blocks = manager.req_to_blocks.get(request_id, [])
            return len(blocks)
    raise AssertionError("No CrossAttentionManager found in coordinator")
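

# Illustrative helper (not used by the tests below, which compute this
# inline): the per-request cross-attention block count under test is just
# the ceiling division of the actual encoder length by the block size.
def _expected_cross_attn_blocks(enc_len: int, block_size: int) -> int:
    """E.g., 200 encoder tokens with block_size=16 -> ceil(200 / 16) = 13."""
    from math import ceil

    return ceil(enc_len / block_size)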


def test_variable_length_cross_attn_block_allocation():
    """Test that cross-attention blocks are allocated per-request based on
    actual encoder input length, not a fixed maximum.

    Fixed max-encoder-length allocation would assign
    `ceil(max_encoder_tokens / block_size)` blocks to every request,
    whereas with dynamic allocation, exactly
    `ceil(actual_encoder_tokens / block_size)` blocks are assigned
    to each request.
    """
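    # Worked example, with block_size=16: per-request allocation gives
    # ceil(500/16)=32, ceil(1000/16)=63, and ceil(200/16)=13 blocks for the
    # three requests below, while a fixed maximum (e.g., Whisper's 1500
    # encoder tokens) would give every request ceil(1500/16)=94 blocks.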
    block_size = 16
    scheduler = _create_encoder_decoder_scheduler(block_size=block_size)

    # Create requests with distinctly different encoder input lengths,
    # simulating variable-length audio inputs to a model like Whisper.
    encoder_lengths = [500, 1000, 200]
    num_prompt_tokens = 100  # Decoder prompt tokens

    requests = []
    for i, enc_len in enumerate(encoder_lengths):
        req = create_requests(
            num_requests=1,
            num_tokens=num_prompt_tokens,
            mm_hashes_list=[[f"enc_hash_{i}"]],
            mm_positions=[[PlaceholderRange(offset=0, length=enc_len)]],
            req_ids=[f"req_{i}"],
        )[0]
        requests.append(req)

    # Add and schedule all requests.
    for req in requests:
        scheduler.add_request(req)

    output = scheduler.schedule()

    # All requests should be scheduled.
    assert len(output.scheduled_new_reqs) == len(requests)

    # Verify cross-attention blocks per request match the actual encoder length.
    from math import ceil

    for req, enc_len in zip(requests, encoder_lengths):
        expected_blocks = ceil(enc_len / block_size)
        actual_blocks = _get_num_cross_attn_blocks(scheduler, req.request_id)

        assert actual_blocks == expected_blocks, (
            f"Request {req.request_id} with {enc_len} encoder tokens: "
            f"expected {expected_blocks} cross-attn blocks, "
            f"got {actual_blocks}"
        )

    # Verify that different encoder lengths produce different block counts,
    # confirming variable-length (not fixed-max) allocation.
    block_counts = [
        _get_num_cross_attn_blocks(scheduler, req.request_id) for req in requests
    ]
    assert len(set(block_counts)) > 1, (
        "All requests have the same number of cross-attn blocks, "
        "suggesting static max-based allocation instead of per-request"
    )


def test_cross_attn_blocks_not_over_allocated():
    """Test that cross-attention blocks are not over-allocated compared to
    what each request actually needs."""
    from math import ceil

    block_size = 16
    max_encoder_tokens = 1500  # e.g., Whisper's max mel-spectrogram length
    scheduler = _create_encoder_decoder_scheduler(block_size=block_size)

    # Request with a small encoder input (much less than the max).
    small_enc_len = 200
    request = create_requests(
        num_requests=1,
        num_tokens=100,
        mm_hashes_list=[["enc_small"]],
        mm_positions=[[PlaceholderRange(offset=0, length=small_enc_len)]],
        req_ids=["req_small"],
    )[0]

    scheduler.add_request(request)
    output = scheduler.schedule()

    assert len(output.scheduled_new_reqs) == 1

    actual_blocks = _get_num_cross_attn_blocks(scheduler, request.request_id)
    expected_blocks = ceil(small_enc_len / block_size)
    max_blocks = ceil(max_encoder_tokens / block_size)
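    # Concretely: ceil(200/16)=13 blocks expected here, versus
    # ceil(1500/16)=94 blocks under fixed max-based allocation.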

    # Blocks should match the actual encoder length.
    assert actual_blocks == expected_blocks, (
        f"Expected {expected_blocks} blocks for {small_enc_len} encoder tokens, "
        f"got {actual_blocks}"
    )

    # Blocks should be strictly less than what max-based allocation would give.
    assert actual_blocks < max_blocks, (
        f"Cross-attn blocks ({actual_blocks}) should be less than max "
        f"({max_blocks}), indicating no over-allocation"
    )


def test_cross_attn_blocks_not_under_allocated():
    """Test that cross-attention blocks are sufficient for each request's
    actual encoder input length. Every encoder token must have a slot.

    Tests various edge cases including exact block boundaries, off-by-one,
    and the minimum/maximum encoder input sizes.
    """
    from math import ceil

    block_size = 16

    # Test various encoder lengths including edge cases around block boundaries.
    test_cases = [
        1,  # Minimum: single encoder token
        block_size - 1,  # Just under one full block
        block_size,  # Exactly one full block
        block_size + 1,  # Just over one block (needs 2 blocks)
        block_size * 10,  # Exact multiple of block size
        block_size * 10 + 1,  # One over exact multiple
        1500,  # Whisper's typical max
    ]
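    # For reference, the expected block counts with block_size=16 are:
    # 1->1, 15->1, 16->1, 17->2, 160->10, 161->11, 1500->94.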

    for enc_len in test_cases:
        scheduler = _create_encoder_decoder_scheduler(block_size=block_size)

        request = create_requests(
            num_requests=1,
            num_tokens=100,
            mm_hashes_list=[[f"enc_{enc_len}"]],
            mm_positions=[[PlaceholderRange(offset=0, length=enc_len)]],
            req_ids=[f"req_{enc_len}"],
        )[0]

        scheduler.add_request(request)
        output = scheduler.schedule()

        assert len(output.scheduled_new_reqs) == 1

        actual_blocks = _get_num_cross_attn_blocks(scheduler, request.request_id)
        expected_blocks = ceil(enc_len / block_size)

        # Number of blocks must be exactly ceil(enc_len / block_size).
        assert actual_blocks == expected_blocks, (
            f"Encoder length {enc_len}: expected {expected_blocks} blocks, "
            f"got {actual_blocks}"
        )

        # Total available slots must be >= encoder tokens (no under-allocation).
        total_slots = actual_blocks * block_size
        assert total_slots >= enc_len, (
            f"Encoder length {enc_len}: total slots {total_slots} < "
            f"needed {enc_len} (under-allocation)"
        )


def test_cross_attn_zero_blocks_without_encoder_inputs():
    """Test that requests without encoder inputs get zero cross-attention
    blocks, even when the scheduler is configured for encoder-decoder."""
    block_size = 16
    scheduler = _create_encoder_decoder_scheduler(block_size=block_size)

    # Create a text-only request (no mm_features).
    request = create_requests(
        num_requests=1,
        num_tokens=100,
        req_ids=["req_text_only"],
    )[0]

    # Text-only request has no encoder inputs.
    assert not request.has_encoder_inputs

    scheduler.add_request(request)
    output = scheduler.schedule()

    assert len(output.scheduled_new_reqs) == 1

    # No cross-attention blocks should be allocated.
    actual_blocks = _get_num_cross_attn_blocks(scheduler, request.request_id)
    assert actual_blocks == 0, (
        f"Text-only request should have 0 cross-attn blocks, got {actual_blocks}"
    )


def test_eagle3_mm_encoder_cache_with_shift():
    """Test EAGLE3 encoder scheduling accounts for shift_computed_tokens.