[Perf][Deepseek] optimize gather_and_maybe_dequant_cache kernel's performance for extremely long sequences (#28029)

Signed-off-by: ganyi <ygan@amd.com>
Author: Pleaplusone
Date: 2025-11-25 10:05:46 +08:00
Committed by: GitHub
Parent: 6f1355a1b7
Commit: 77e10c9cab

6 changed files with 131 additions and 105 deletions
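
What changed, in short: the gather_and_maybe_dequant_cache wrapper (and the cache op behind it) no longer takes a batch_size. The metadata builder instead precomputes token_to_seq, a per-token map from each destination token to the prefill request that owns it, together with num_tokens, the number of context tokens gathered in the current chunk. With that map the gather work can be partitioned per token rather than per sequence, which is what matters when the batch holds one extremely long context and per-sequence partitioning would leave most of the GPU idle.

The snippet below is a pure-PyTorch reading of what the new arguments describe; it is a sketch for illustration only, not the CUDA/HIP kernel. The cache layout ([num_blocks, block_size, entry]), the helper name gather_reference, and the scalar loop are assumptions, and fp8 dequantization via kv_cache_dtype / scale is omitted.

    import torch

    def gather_reference(
        src_cache: torch.Tensor,    # assumed [num_blocks, block_size, entry] paged cache
        dst: torch.Tensor,          # workspace, at least [num_tokens, entry]
        block_table: torch.Tensor,  # [num_prefills, max_blocks_per_seq]
        cu_seq_lens: torch.Tensor,  # [num_prefills + 1], cumulative gathered tokens per request
        token_to_seq: torch.Tensor, # [num_tokens], destination token -> owning request
        num_tokens: int,            # total context tokens in this chunk
        seq_starts: torch.Tensor | None = None,  # per-request offset into its cached context
    ) -> None:
        block_size = src_cache.shape[1]
        for t in range(num_tokens):          # the real kernel can parallelize over t
            s = int(token_to_seq[t])         # O(1) lookup, no search over the batch
            pos = t - int(cu_seq_lens[s])    # position of this token within request s
            if seq_starts is not None:
                pos += int(seq_starts[s])    # skip context already handled by earlier chunks
            blk = int(block_table[s, pos // block_size])
            dst[t] = src_cache[blk, pos % block_size]

Read this way, the available parallelism scales with the total number of context tokens in a chunk rather than with the number of requests, so a batch containing a single, extremely long request no longer serializes the gather.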


@@ -2201,7 +2201,8 @@ def gather_and_maybe_dequant_cache(
     dst: torch.Tensor,
     block_table: torch.Tensor,
     cu_seq_lens: torch.Tensor,
-    batch_size: int,
+    token_to_seq: torch.Tensor,
+    num_tokens: int,
     kv_cache_dtype: str,
     scale: torch.Tensor,
     seq_starts: torch.Tensor | None = None,
@@ -2211,7 +2212,8 @@ def gather_and_maybe_dequant_cache(
         dst,
         block_table,
         cu_seq_lens,
-        batch_size,
+        token_to_seq,
+        num_tokens,
         kv_cache_dtype,
         scale,
         seq_starts,


@@ -340,6 +340,8 @@ class MLACommonPrefillMetadata:
         max_seq_lens: list[int]
         seq_lens: torch.Tensor
         workspace: torch.Tensor
+        token_to_seq: torch.Tensor
+        chunk_total_token: list[int]
         # for mla DCP
         padded_local_chunk_seq_lens: list[list[int]] | None = None
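
Reading the builder changes below: token_to_seq is an int32 tensor of shape [num_chunks, max_tokens_in_any_chunk], where entry [i, t] names the prefill request that owns destination token t of chunk i, zero-padded past chunk_total_token[i]. chunk_total_token holds, per chunk, the total number of context tokens to gather, i.e. the last column of the cumulative per-request lengths.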
@@ -839,6 +841,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             torch.cumsum(
                 chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
             )
+            chunk_total_token = cu_seq_lens_cpu[:, -1]
+            max_token_num_over_chunk = chunk_total_token.max().item()
+            token_to_seq_tensor_cpu = torch.zeros(
+                [num_chunks, max_token_num_over_chunk], dtype=torch.int32
+            )
+            range_idx = torch.arange(num_prefills, dtype=torch.int32)
+            for i in range(num_chunks):
+                chunk_token_to_seq_tensor = torch.repeat_interleave(
+                    range_idx, chunk_seq_lens[i]
+                )
+                chunk_len = chunk_token_to_seq_tensor.shape[0]
+                token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor
             if self.dcp_world_size > 1:
                 local_context_lens_allranks = get_dcp_local_seq_lens(
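
For concreteness, a small worked example of the construction above with made-up lengths (two chunks, two prefill requests; values chosen only for illustration):

    import torch

    chunk_seq_lens = torch.tensor([[4, 2],
                                   [1, 3]], dtype=torch.int32)  # context tokens per chunk and request
    num_chunks, num_prefills = chunk_seq_lens.shape

    cu_seq_lens_cpu = torch.zeros(num_chunks, num_prefills + 1, dtype=torch.int32)
    torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32)
    # cu_seq_lens_cpu == [[0, 4, 6],
    #                     [0, 1, 4]]

    chunk_total_token = cu_seq_lens_cpu[:, -1]               # tensor([6, 4])
    max_token_num_over_chunk = int(chunk_total_token.max())  # 6

    token_to_seq_cpu = torch.zeros(num_chunks, max_token_num_over_chunk, dtype=torch.int32)
    range_idx = torch.arange(num_prefills, dtype=torch.int32)
    for i in range(num_chunks):
        mapping = torch.repeat_interleave(range_idx, chunk_seq_lens[i])
        token_to_seq_cpu[i, : mapping.shape[0]] = mapping
    # token_to_seq_cpu == [[0, 0, 0, 0, 1, 1],
    #                      [0, 1, 1, 1, 0, 0]]
    # The trailing zeros in row 1 are padding; only the first chunk_total_token[i]
    # entries of each row are consumed, since the kernel receives num_tokens per chunk.

Each row is then moved to the device with a non-blocking copy and handed to the kernel together with chunk_total_token[i], as the call-site hunk further down shows.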
@@ -906,6 +921,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                 seq_lens=chunk_seq_lens,
+                token_to_seq=token_to_seq_tensor_cpu.to(
+                    device, non_blocking=True
+                ),
+                chunk_total_token=chunk_total_token.tolist(),
                 workspace=self.chunked_prefill_workspace,
                 padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
                 local_context_lens_allranks=local_context_lens_allranks.tolist(),
@@ -922,6 +941,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                 seq_lens=chunk_seq_lens,
+                token_to_seq=token_to_seq_tensor_cpu.to(
+                    device, non_blocking=True
+                ),
+                chunk_total_token=chunk_total_token,
                 workspace=self.chunked_prefill_workspace,
             )
@@ -1638,16 +1661,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         output = None
         iters = len(prefill_metadata.chunked_context.seq_tot)
         workspace = prefill_metadata.chunked_context.workspace
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
             ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_table,
                 cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
-                batch_size=attn_metadata.num_prefills,
+                token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
+                num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
                 kv_cache_dtype=self.kv_cache_dtype,
                 scale=k_scale,
                 seq_starts=prefill_metadata.chunked_context.starts[i],