[Misc] Add unit tests for chunked local attention (#21692)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@@ -40,7 +40,8 @@ def create_common_attn_metadata(
|
||||
batch_spec: BatchSpec,
|
||||
block_size: int,
|
||||
device: torch.device,
|
||||
max_block_idx: int = 1000) -> CommonAttentionMetadata:
|
||||
max_block_idx: int = 1000,
|
||||
arange_block_indices: bool = False) -> CommonAttentionMetadata:
|
||||
"""Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
|
||||
# Create query start locations
|
||||
query_start_loc = torch.zeros(batch_spec.batch_size + 1,
|
||||
@@ -65,19 +66,28 @@ def create_common_attn_metadata(
|
||||
]
|
||||
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
|
||||
|
||||
# Create block table (random for testing)
|
||||
# Create block table and slot mapping
|
||||
max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
|
||||
block_table_tensor = torch.randint(0,
|
||||
max_block_idx,
|
||||
(batch_spec.batch_size, max_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# Create slot mapping
|
||||
slot_mapping = torch.randint(0,
|
||||
max_block_idx, (num_tokens, ),
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
if arange_block_indices:
|
||||
num_blocks = batch_spec.batch_size * max_blocks
|
||||
block_table_tensor = torch.arange(num_blocks,
|
||||
dtype=torch.int32,
|
||||
device=device).view(
|
||||
batch_spec.batch_size,
|
||||
max_blocks)
|
||||
slot_mapping = torch.arange(num_tokens,
|
||||
dtype=torch.int64,
|
||||
device=device).view(num_tokens)
|
||||
else:
|
||||
block_table_tensor = torch.randint(0,
|
||||
max_block_idx,
|
||||
(batch_spec.batch_size, max_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
slot_mapping = torch.randint(0,
|
||||
max_block_idx, (num_tokens, ),
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
|
||||
# Calculate max query length
|
||||
max_query_len = max(batch_spec.query_lens)
|
||||
|
||||
Reference in New Issue
Block a user