[FlexAttention] allow custom mask mod (#37692)

Signed-off-by: Angel Li <liangel@meta.com>
This commit is contained in:
liangel-02
2026-03-24 16:03:24 -04:00
committed by GitHub
parent 54b0578ada
commit 8c47fdfdb1
2 changed files with 121 additions and 16 deletions

View File

@@ -14,6 +14,7 @@ from tests.v1.attention.utils import (
create_vllm_config,
)
from vllm.v1.attention.backends.flex_attention import (
BlockSparsityHint,
FlexAttentionMetadataBuilder,
physical_to_logical_mapping,
)
@@ -223,5 +224,55 @@ def test_physical_to_logical_mapping_handles_reused_blocks():
assert out2[0, 2].item() == 1
@pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
reason="CUDA not available or PyTorch version < 2.9",
)
def test_block_sparsity_hint_prunes_blocks():
"""Test that BlockSparsityHint prunes KV blocks from the direct build path.
Uses a hint that only keeps the diagonal (q_block == kv_block) to verify
that off-diagonal blocks are excluded from the resulting BlockMask.
"""
device = torch.device("cuda")
vllm_config = create_vllm_config(
model_name="facebook/opt-125m",
block_size=16,
max_model_len=1024,
)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
batch_spec = BatchSpec(
seq_lens=[256],
query_lens=[256],
name="test_sparsity_hint",
)
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device
)
builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device)
metadata_no_hint = builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
metadata_no_hint.block_mask = metadata_no_hint._build_block_mask_direct()
assert metadata_no_hint.block_mask.kv_num_blocks.max().item() > 1
def diagonal_hint(q_block_idx, kv_block_idx, block_size):
return q_block_idx == kv_block_idx
metadata_with_hint = builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
metadata_with_hint.block_sparsity_hint = BlockSparsityHint(
hint_fn=diagonal_hint,
)
metadata_with_hint.block_mask = metadata_with_hint._build_block_mask_direct()
assert metadata_with_hint.block_mask.kv_num_blocks.max().item() <= 1
if __name__ == "__main__":
pytest.main([__file__])