[Perf][Deepseek] optimize gather_and_maybe_dequant_cache kernel's perf for extremely long sequence (#28029)
Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
@@ -921,12 +921,16 @@ def test_gather_and_maybe_dequant_cache_mla(
|
||||
)
|
||||
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
|
||||
|
||||
seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
|
||||
seq_len_tensor = torch.randint(
|
||||
max_seq_len, max_seq_len + 1, (batch_size,), device=device
|
||||
)
|
||||
|
||||
total_tokens = seq_len_tensor.sum()
|
||||
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
|
||||
cu_seq_lens[0] = 0
|
||||
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
|
||||
token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
|
||||
token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)
|
||||
print("seq_len_tensor", seq_len_tensor)
|
||||
|
||||
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
|
||||
@@ -977,7 +981,8 @@ def test_gather_and_maybe_dequant_cache_mla(
|
||||
dst,
|
||||
block_table,
|
||||
cu_seq_lens,
|
||||
batch_size,
|
||||
token_to_seq,
|
||||
total_tokens,
|
||||
kv_cache_dtype,
|
||||
scale,
|
||||
None,
|
||||
@@ -990,7 +995,8 @@ def test_gather_and_maybe_dequant_cache_mla(
|
||||
dst,
|
||||
block_table,
|
||||
cu_seq_lens,
|
||||
batch_size,
|
||||
token_to_seq,
|
||||
total_tokens,
|
||||
kv_cache_dtype,
|
||||
scale,
|
||||
None,
|
||||
|
||||
Reference in New Issue
Block a user