[Kernel] cuda kernels for upcoming decode context parallel feature (#23791)
Co-authored-by: hongchao <hongchao@msh.team>
This commit is contained in:
@@ -790,6 +790,78 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
|
||||
torch.testing.assert_close(dst, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", [512])
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", [64])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
@pytest.mark.parametrize("max_seq_len", [512])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||
@pytest.mark.parametrize("kv_cache_dtype",
|
||||
["auto"]) # You can also test "fp8" if needed.
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
|
||||
num_blocks, max_seq_len, batch_size, dtype,
|
||||
kv_cache_dtype, device):
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device)
|
||||
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
|
||||
|
||||
seq_len_tensor = torch.randint(0,
|
||||
max_seq_len + 1, (batch_size, ),
|
||||
device=device)
|
||||
|
||||
total_tokens = seq_len_tensor.sum()
|
||||
cu_seq_lens = torch.empty((batch_size + 1),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
cu_seq_lens[0] = 0
|
||||
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
|
||||
print("seq_len_tensor", seq_len_tensor)
|
||||
|
||||
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
|
||||
block_table = torch.empty((batch_size, num_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
for b in range(batch_size):
|
||||
perm = torch.randperm(num_blocks, device=device)
|
||||
block_table[b, :] = perm
|
||||
|
||||
dst = torch.zeros((total_tokens, entry_size),
|
||||
dtype=src_cache.dtype,
|
||||
device=device)
|
||||
|
||||
expected_batches = []
|
||||
for b in range(batch_size):
|
||||
s = seq_len_tensor[b]
|
||||
if s == 0:
|
||||
continue
|
||||
tot = tot_blocks_tensor[b]
|
||||
blocks = block_table[b, :tot].tolist()
|
||||
|
||||
gathered_rows = []
|
||||
for i in range(tot - 1):
|
||||
gathered_rows.append(src_cache[blocks[i]])
|
||||
remaining = s - (tot - 1) * block_size
|
||||
gathered_rows.append(src_cache[blocks[-1], :remaining, :])
|
||||
|
||||
batch_expected = torch.cat(gathered_rows, dim=0)
|
||||
expected_batches.append(batch_expected)
|
||||
expected = torch.cat(expected_batches, dim=0)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.cp_gather_cache,
|
||||
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
|
||||
torch.testing.assert_close(dst, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
|
||||
|
||||
Reference in New Issue
Block a user