[Attention] Make local attention backend agnostic (#21093)
This commit is contained in:
@@ -272,11 +272,14 @@ def infer_global_hyperparameters(
|
||||
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
||||
def make_local_attention_virtual_batches(
|
||||
attn_chunk_size: int,
|
||||
query_start_loc_np: np.ndarray,
|
||||
seq_lens_np: np.ndarray,
|
||||
block_table: torch.Tensor,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
block_size: int = 0,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
|
||||
) -> CommonAttentionMetadata:
|
||||
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
|
||||
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
device = common_attn_metadata.query_start_loc.device
|
||||
|
||||
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||
actual_batch_size = seq_lens_np.shape[0]
|
||||
|
||||
@@ -339,6 +342,7 @@ def make_local_attention_virtual_batches(
|
||||
attn_chunk_size,
|
||||
dtype=np.int32)
|
||||
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
|
||||
num_computed_tokens_local = seqlens_k_local - seqlens_q_local
|
||||
|
||||
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
|
||||
(rarange * attn_chunk_size + \
|
||||
@@ -380,8 +384,22 @@ def make_local_attention_virtual_batches(
|
||||
block_table_local = block_table[batch_indices, block_indices]\
|
||||
.view(virtual_batches, -1)
|
||||
|
||||
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
|
||||
block_table_local
|
||||
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
|
||||
seq_lens_cpu = torch.from_numpy(seqlens_k_local)
|
||||
|
||||
return CommonAttentionMetadata(
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
query_start_loc=query_start_loc_cpu.to(device=device,
|
||||
non_blocking=True),
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
|
||||
num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
|
||||
num_reqs=len(seq_lens_cpu),
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
max_query_len=seqlens_q_local.max(),
|
||||
block_table_tensor=block_table_local,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
)
|
||||
|
||||
|
||||
def split_decodes_and_prefills(
|
||||
|
||||
Reference in New Issue
Block a user