[Fix][FlexAttention] return max logical block index to handle reused blocks (#30915)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Author: Yifan Qiao <yifanqiao@berkeley.edu>
Date:   2025-12-17 22:42:21 -08:00
Committed by: GitHub
Commit: 11a89cf95c (parent: e3ab93c896)
2 changed files with 42 additions and 4 deletions
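The fix makes the inverse block mapping pick the maximum (latest) logical index when a physical KV-cache block id occurs more than once in a request's block table. For intuition, here is a minimal sketch of such a mapping. This is not the vLLM implementation: the helper name physical_to_logical_mapping_sketch and the scatter_reduce_-based approach are assumptions; only the argument names mirror the call in the test below.

import torch


def physical_to_logical_mapping_sketch(
    block_table: torch.Tensor,  # [num_reqs, max_num_blocks], physical block ids
    seq_lens: torch.Tensor,  # [num_reqs], sequence lengths in tokens
    block_size: int,
    total_blocks: int,
) -> torch.Tensor:
    """Map physical block id -> latest logical block index, per request.

    Returns a [num_reqs, total_blocks] tensor; -1 marks physical blocks the
    request does not use. Hypothetical sketch, not the vLLM kernel.
    """
    num_reqs, max_blocks = block_table.shape
    out = torch.full((num_reqs, total_blocks), -1, dtype=torch.int64)
    logical = torch.arange(max_blocks, dtype=torch.int64).expand(num_reqs, -1)
    # Only the first ceil(seq_len / block_size) logical blocks are live;
    # padding entries (often 0) must not mark physical block 0 as used.
    num_live = (seq_lens.to(torch.int64) + block_size - 1) // block_size
    src = torch.where(
        logical < num_live.unsqueeze(1), logical, torch.full_like(logical, -1)
    )
    # "amax" keeps the largest, i.e. latest, logical index when the same
    # physical block id appears more than once -- the reuse case this fix
    # addresses.
    out.scatter_reduce_(1, block_table.to(torch.int64), src, reduce="amax")
    return out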


@@ -15,7 +15,10 @@ from tests.v1.attention.utils import (
     create_standard_kv_cache_spec,
     create_vllm_config,
 )
-from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder
+from vllm.v1.attention.backends.flex_attention import (
+    FlexAttentionMetadataBuilder,
+    physical_to_logical_mapping,
+)
 
 from ..models.utils import check_embeddings_close, check_logprobs_close
@@ -205,5 +208,31 @@ def test_block_mask_direct_vs_slow_path():
     )
 
 
+def test_physical_to_logical_mapping_handles_reused_blocks():
+    """Regression test: reused physical blocks map to the latest logical block.
+
+    For sliding-window / hybrid attention layers, physical KV-cache blocks can
+    be reused over time. The inverse mapping must therefore select the latest
+    logical block index for a physical block id.
+    """
+    # Padding should not make physical block 0 look live.
+    block_table = torch.tensor([[6, 0, 0, 0]], dtype=torch.int32)
+    seq_lens = torch.tensor([1 * 16], dtype=torch.int32)  # only 1 block valid
+    out = physical_to_logical_mapping(
+        block_table=block_table, seq_lens=seq_lens, block_size=16, total_blocks=10
+    )
+    assert out[0, 0].item() == -1
+    assert out[0, 6].item() == 0
+
+    # If a physical block id appears multiple times (block reuse), the mapping
+    # should point to the latest logical block index.
+    block_table2 = torch.tensor([[2, 2, 5]], dtype=torch.int32)
+    seq_lens2 = torch.tensor([3 * 16], dtype=torch.int32)
+    out2 = physical_to_logical_mapping(
+        block_table=block_table2, seq_lens=seq_lens2, block_size=16, total_blocks=8
+    )
+    assert out2[0, 2].item() == 1
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
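As a sanity check, the hypothetical sketch above reproduces the second test case's expectation: physical block 2 is reused at logical positions 0 and 1, so the mapping reports the latest index, 1.

>>> bt = torch.tensor([[2, 2, 5]], dtype=torch.int32)
>>> sl = torch.tensor([3 * 16], dtype=torch.int32)
>>> physical_to_logical_mapping_sketch(bt, sl, block_size=16, total_blocks=8)[0, 2]
tensor(1)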