[Kernel] Support decode context parallelism on Blackwell with CUTLASS MLA (#24385)

Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Ming Yang
2025-09-07 18:27:12 -07:00
committed by GitHub
parent 795b6951cd
commit 86173ad593
5 changed files with 63 additions and 32 deletions

View File

@@ -76,6 +76,7 @@ g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
@@ -138,7 +139,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
workspace: torch.Tensor,
sm_scale: float,
num_kv_splits: int,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
assert (q_nope.ndim == 3
), f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
assert (
@@ -193,9 +194,13 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype)
else q_nope.dtype)
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
lse = (torch.empty(
(B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
if self.need_to_return_lse_for_decode else torch.Tensor())
ops.sm100_cutlass_mla_decode(
out,
lse,
q_nope,
q_pe,
kv_c_and_k_pe_cache,
@@ -205,7 +210,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
sm_scale,
num_kv_splits,
)
return out[:, :H].contiguous()
returned_lse = lse[:, :H].contiguous(
) if self.need_to_return_lse_for_decode else lse
return out[:, :H].contiguous(), returned_lse
def _sm100_forward_decode(
self,
@@ -213,7 +220,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
@@ -226,13 +233,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_nope = q_nope.clone()
q_pe = q_pe.clone()
o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table,
self._workspace.get_buf(),
self.scale, self._num_kv_splits)
o, lse = self._sm100_cutlass_mla_decode(
q_nope,
q_pe,
kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table,
self._workspace.get_buf(),
self.scale,
self._num_kv_splits,
)
return o
return o, (lse if self.need_to_return_lse_for_decode else None)
# TODO: Currently we leave it here only for backup in case something is
# wrong with the new SM100 CUTLASS MLA kernel
@@ -286,4 +298,4 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata), None
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata), None
attn_metadata)