[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)

Signed-off-by: hongchao <hongchao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: hongchao <hongchao@msh.team>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
yzds
2025-09-06 13:24:05 +08:00
committed by GitHub
parent 3c529fc994
commit ac201a0eaf
27 changed files with 999 additions and 230 deletions

View File

@@ -693,16 +693,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
cache_ops.def(
"cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor cp_local_token_select_indices,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
&cp_fused_concat_and_cache_mla);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "