[Perf][Deepseek] optimize gather_and_maybe_dequant_cache kernel's perf for extremely long sequence (#28029)

Signed-off-by: ganyi <ygan@amd.com>
Author: Pleaplusone
Date: 2025-11-25 10:05:46 +08:00
Committed by: GitHub
Parent: 6f1355a1b7
Commit: 77e10c9cab
6 changed files with 131 additions and 105 deletions
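
The functional change is visible in the test diff below: the kernel's interface drops the per-batch `batch_size` argument in favor of a per-token `token_to_seq` index plus a scalar `total_tokens`, suggesting work is distributed per token rather than per sequence. A minimal CPU sketch of such a token-parallel gather, assuming a paged source cache of shape [num_blocks, block_size, entry_size] (the function name and layout here are illustrative, not the kernel's actual code):

import torch

# Illustrative reference for a token-parallel gather. Everything except
# cu_seq_lens / token_to_seq / block_table is an assumption for this sketch.
def gather_reference(src_cache, block_table, cu_seq_lens, token_to_seq):
    block_size = src_cache.shape[1]
    total_tokens = int(cu_seq_lens[-1])
    dst = torch.empty((total_tokens, src_cache.shape[2]), dtype=src_cache.dtype)
    # One unit of work per destination token, so a single extremely long
    # sequence no longer serializes the whole gather.
    for tok in range(total_tokens):
        seq = int(token_to_seq[tok])       # owning sequence of this token
        off = tok - int(cu_seq_lens[seq])  # position within that sequence
        blk = int(block_table[seq, off // block_size])
        dst[tok] = src_cache[blk, off % block_size]
    return dst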


@@ -921,12 +921,16 @@ def test_gather_and_maybe_dequant_cache_mla(
     )
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
-    seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
+    seq_len_tensor = torch.randint(
+        max_seq_len, max_seq_len + 1, (batch_size,), device=device
+    )
     total_tokens = seq_len_tensor.sum()
     cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
     cu_seq_lens[0] = 0
     cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
+    token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
+    token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)
     print("seq_len_tensor", seq_len_tensor)
     tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
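
For intuition, with two sequences of (hypothetical) lengths 3 and 2, the cumulative offsets and the new per-token map come out as:

>>> import torch
>>> seq_len = torch.tensor([3, 2])
>>> seq_len.cumsum(dim=0)
tensor([3, 5])   # so cu_seq_lens is [0, 3, 5]
>>> torch.repeat_interleave(torch.arange(2), seq_len)
tensor([0, 0, 0, 1, 1])   # token_to_seq: one sequence index per token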
@@ -977,7 +981,8 @@ def test_gather_and_maybe_dequant_cache_mla(
         dst,
         block_table,
         cu_seq_lens,
-        batch_size,
+        token_to_seq,
+        total_tokens,
         kv_cache_dtype,
         scale,
         None,
@@ -990,7 +995,8 @@ def test_gather_and_maybe_dequant_cache_mla(
         dst,
         block_table,
         cu_seq_lens,
-        batch_size,
+        token_to_seq,
+        total_tokens,
         kv_cache_dtype,
         scale,
         None,
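
Both call sites change identically: `token_to_seq` and `total_tokens` now go where `batch_size` used to, consistent with a launch grid sized by the total token count rather than by the number of sequences, which is what helps when one sequence is extremely long.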