[Misc] Add unit tests for chunked local attention (#21692)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Author:       Yong Hoon Shin
Date:         2025-07-31 11:39:16 -07:00
Committed by: GitHub
Parent:       9e0726e5bf
Commit:       71470bc4af

2 changed files with 219 additions and 13 deletions

@@ -40,7 +40,8 @@ def create_common_attn_metadata(
         batch_spec: BatchSpec,
         block_size: int,
         device: torch.device,
-        max_block_idx: int = 1000) -> CommonAttentionMetadata:
+        max_block_idx: int = 1000,
+        arange_block_indices: bool = False) -> CommonAttentionMetadata:
     """Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
     # Create query start locations
     query_start_loc = torch.zeros(batch_spec.batch_size + 1,
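The new keyword defaults to False, so existing callers of create_common_attn_metadata are unchanged. As a minimal sketch (the block_size and device values below are placeholders, not taken from this commit), a chunked-local-attention test could opt into deterministic indices like this:

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,                   # an existing BatchSpec for the test batch
        block_size=16,                # placeholder block size
        device=torch.device("cpu"),   # placeholder device
        arange_block_indices=True)    # consecutive block ids / identity slot mapping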
@@ -65,19 +66,28 @@ def create_common_attn_metadata(
     ]
     num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

-    # Create block table (random for testing)
+    # Create block table and slot mapping
     max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
-    block_table_tensor = torch.randint(0,
-                                       max_block_idx,
-                                       (batch_spec.batch_size, max_blocks),
-                                       dtype=torch.int32,
-                                       device=device)
-
-    # Create slot mapping
-    slot_mapping = torch.randint(0,
-                                 max_block_idx, (num_tokens, ),
-                                 dtype=torch.int64,
-                                 device=device)
+    if arange_block_indices:
+        num_blocks = batch_spec.batch_size * max_blocks
+        block_table_tensor = torch.arange(num_blocks,
+                                          dtype=torch.int32,
+                                          device=device).view(
+                                              batch_spec.batch_size,
+                                              max_blocks)
+        slot_mapping = torch.arange(num_tokens,
+                                    dtype=torch.int64,
+                                    device=device).view(num_tokens)
+    else:
+        block_table_tensor = torch.randint(0,
+                                           max_block_idx,
+                                           (batch_spec.batch_size, max_blocks),
+                                           dtype=torch.int32,
+                                           device=device)
+        slot_mapping = torch.randint(0,
+                                     max_block_idx, (num_tokens, ),
+                                     dtype=torch.int64,
+                                     device=device)

     # Calculate max query length
     max_query_len = max(batch_spec.query_lens)
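For intuition, a standalone illustration (not part of the diff; the sizes are made up) of what the arange_block_indices branch yields: block i of request b gets id b * max_blocks + i, and token t is written to slot t, so the KV-cache layout is fully predictable when checking chunked local attention outputs:

    import torch

    batch_size, max_blocks, num_tokens = 2, 4, 6   # placeholder sizes
    block_table = torch.arange(batch_size * max_blocks,
                               dtype=torch.int32).view(batch_size, max_blocks)
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64)
    # block_table  -> [[0, 1, 2, 3], [4, 5, 6, 7]]
    # slot_mapping -> [0, 1, 2, 3, 4, 5]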