[FlexAttention] allow custom mask mod (#37692)
Signed-off-by: Angel Li <liangel@meta.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from tests.v1.attention.utils import (
|
||||
create_vllm_config,
|
||||
)
|
||||
from vllm.v1.attention.backends.flex_attention import (
|
||||
BlockSparsityHint,
|
||||
FlexAttentionMetadataBuilder,
|
||||
physical_to_logical_mapping,
|
||||
)
|
||||
@@ -223,5 +224,55 @@ def test_physical_to_logical_mapping_handles_reused_blocks():
|
||||
assert out2[0, 2].item() == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
    reason="CUDA not available or PyTorch version < 2.9",
)
def test_block_sparsity_hint_prunes_blocks():
    """Verify that BlockSparsityHint drops KV blocks from the direct build path.

    Builds the same metadata twice: once without a hint (baseline, where the
    direct-built BlockMask keeps multiple KV blocks per query block) and once
    with a diagonal-only hint (q_block == kv_block), which must leave at most
    one KV block per query block.
    """
    dev = torch.device("cuda")

    config = create_vllm_config(
        model_name="facebook/opt-125m",
        block_size=16,
        max_model_len=1024,
    )
    spec = create_standard_kv_cache_spec(config)

    attn_meta = create_common_attn_metadata(
        BatchSpec(
            seq_lens=[256],
            query_lens=[256],
            name="test_sparsity_hint",
        ),
        config.cache_config.block_size,
        dev,
    )

    meta_builder = FlexAttentionMetadataBuilder(spec, [], config, dev)

    # Baseline: with no sparsity hint the direct build retains off-diagonal
    # KV blocks, so some query block sees more than one KV block.
    dense = meta_builder.build(
        common_prefix_len=0, common_attn_metadata=attn_meta
    )
    dense.block_mask = dense._build_block_mask_direct()
    assert dense.block_mask.kv_num_blocks.max().item() > 1

    # With a diagonal-only hint, every off-diagonal KV block must be pruned
    # from the resulting BlockMask.
    sparse = meta_builder.build(
        common_prefix_len=0, common_attn_metadata=attn_meta
    )
    sparse.block_sparsity_hint = BlockSparsityHint(
        hint_fn=lambda q_block_idx, kv_block_idx, block_size: (
            q_block_idx == kv_block_idx
        ),
    )
    sparse.block_mask = sparse._build_block_mask_direct()
    assert sparse.block_mask.kv_num_blocks.max().item() <= 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # pytest.main returns an ExitCode; propagate it so running this file
    # directly exits non-zero on test failure (the original discarded it,
    # making direct invocation always exit 0).
    raise SystemExit(pytest.main([__file__]))
|
||||
|
||||
Reference in New Issue
Block a user