[Perf][Deepseek] Optimize the gather_and_maybe_dequant_cache kernel's performance for extremely long sequences (#28029)
Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
@@ -2201,7 +2201,8 @@ def gather_and_maybe_dequant_cache(
|
||||
dst: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cu_seq_lens: torch.Tensor,
|
||||
batch_size: int,
|
||||
token_to_seq: torch.Tensor,
|
||||
num_tokens: int,
|
||||
kv_cache_dtype: str,
|
||||
scale: torch.Tensor,
|
||||
seq_starts: torch.Tensor | None = None,
|
||||
@@ -2211,7 +2212,8 @@ def gather_and_maybe_dequant_cache(
|
||||
dst,
|
||||
block_table,
|
||||
cu_seq_lens,
|
||||
batch_size,
|
||||
token_to_seq,
|
||||
num_tokens,
|
||||
kv_cache_dtype,
|
||||
scale,
|
||||
seq_starts,
|
||||
|
||||
@@ -340,6 +340,8 @@ class MLACommonPrefillMetadata:
|
||||
max_seq_lens: list[int]
|
||||
seq_lens: torch.Tensor
|
||||
workspace: torch.Tensor
|
||||
token_to_seq: torch.Tensor
|
||||
chunk_total_token: list[int]
|
||||
|
||||
# for mla DCP
|
||||
padded_local_chunk_seq_lens: list[list[int]] | None = None
|
||||
@@ -839,6 +841,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
torch.cumsum(
|
||||
chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
|
||||
)
|
||||
chunk_total_token = cu_seq_lens_cpu[:, -1]
|
||||
|
||||
max_token_num_over_chunk = chunk_total_token.max().item()
|
||||
token_to_seq_tensor_cpu = torch.zeros(
|
||||
[num_chunks, max_token_num_over_chunk], dtype=torch.int32
|
||||
)
|
||||
range_idx = torch.arange(num_prefills, dtype=torch.int32)
|
||||
for i in range(num_chunks):
|
||||
chunk_token_to_seq_tensor = torch.repeat_interleave(
|
||||
range_idx, chunk_seq_lens[i]
|
||||
)
|
||||
chunk_len = chunk_token_to_seq_tensor.shape[0]
|
||||
token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
local_context_lens_allranks = get_dcp_local_seq_lens(
|
||||
@@ -906,6 +921,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
token_to_seq=token_to_seq_tensor_cpu.to(
|
||||
device, non_blocking=True
|
||||
),
|
||||
chunk_total_token=chunk_total_token.tolist(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
|
||||
local_context_lens_allranks=local_context_lens_allranks.tolist(),
|
||||
@@ -922,6 +941,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
token_to_seq=token_to_seq_tensor_cpu.to(
|
||||
device, non_blocking=True
|
||||
),
|
||||
chunk_total_token=chunk_total_token,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
|
||||
@@ -1638,16 +1661,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
output = None
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
workspace = prefill_metadata.chunked_context.workspace
|
||||
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
|
||||
ops.gather_and_maybe_dequant_cache(
|
||||
src_cache=kv_c_and_k_pe_cache,
|
||||
dst=workspace,
|
||||
block_table=prefill_metadata.block_table,
|
||||
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
|
||||
batch_size=attn_metadata.num_prefills,
|
||||
token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
|
||||
num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=k_scale,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
|
||||
Reference in New Issue
Block a user