[NVIDIA] Support Cutlass MLA for Blackwell GPUs (#16032)

Signed-off-by: kaixih <kaixih@nvidia.com>
This commit is contained in:
Kaixi Hou
2025-04-27 06:29:21 -07:00
committed by GitHub
parent 756848e79e
commit ed7a29d9f8
8 changed files with 403 additions and 5 deletions

View File

@@ -130,6 +130,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
// Compute MLA decode using cutlass.
ops.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, float scale) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(