[Attention] MLA with chunked prefill (#12639)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Patrick Horn <patrick.horn@gmail.com> Co-authored-by: simon-mo <xmo@berkeley.edu> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -682,8 +682,6 @@ def test_swap_blocks_mla(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_cache, dst_cache, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
cond=(kv_lora_rank == KV_LORA_RANKS[0]
|
||||
and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
|
||||
)
|
||||
|
||||
ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
|
||||
@@ -694,3 +692,76 @@ def test_swap_blocks_mla(
|
||||
dst_cache[dst].cpu(),
|
||||
msg=f"Block {src} from src should have been swapped to block "
|
||||
f"{dst} in dst_cache.")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", [512])
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", [64])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
@pytest.mark.parametrize("max_seq_len", [512])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||
@pytest.mark.parametrize("kv_cache_dtype",
|
||||
["auto"]) # You can also test "fp8" if needed.
|
||||
@pytest.mark.parametrize("align_cache", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
|
||||
num_blocks, max_seq_len, batch_size, dtype,
|
||||
kv_cache_dtype, align_cache, device):
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device, align_cache)
|
||||
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
|
||||
|
||||
seq_len_tensor = torch.randint(0,
|
||||
max_seq_len + 1, (batch_size, ),
|
||||
device=device)
|
||||
|
||||
total_tokens = seq_len_tensor.sum()
|
||||
cu_seq_lens = torch.empty((batch_size + 1),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
cu_seq_lens[0] = 0
|
||||
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
|
||||
print("seq_len_tensor", seq_len_tensor)
|
||||
|
||||
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
|
||||
block_table = torch.empty((batch_size, num_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
for b in range(batch_size):
|
||||
perm = torch.randperm(num_blocks, device=device)
|
||||
block_table[b, :] = perm
|
||||
|
||||
dst = torch.zeros((total_tokens, entry_size),
|
||||
dtype=src_cache.dtype,
|
||||
device=device)
|
||||
|
||||
expected_batches = []
|
||||
for b in range(batch_size):
|
||||
s = seq_len_tensor[b]
|
||||
if s == 0:
|
||||
continue
|
||||
tot = tot_blocks_tensor[b]
|
||||
blocks = block_table[b, :tot].tolist()
|
||||
|
||||
gathered_rows = []
|
||||
for i in range(tot - 1):
|
||||
gathered_rows.append(src_cache[blocks[i]])
|
||||
remaining = s - (tot - 1) * block_size
|
||||
gathered_rows.append(src_cache[blocks[-1], :remaining, :])
|
||||
|
||||
batch_expected = torch.cat(gathered_rows, dim=0)
|
||||
expected_batches.append(batch_expected)
|
||||
expected = torch.cat(expected_batches, dim=0)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.gather_cache,
|
||||
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
|
||||
torch.testing.assert_close(dst, expected)
|
||||
|
||||
Reference in New Issue
Block a user