[Kernel] Support decode context parallelism on Blackwell with CUTLASS MLA (#24385)
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -76,6 +76,7 @@ g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
|
||||
|
||||
|
||||
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -138,7 +139,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
workspace: torch.Tensor,
|
||||
sm_scale: float,
|
||||
num_kv_splits: int,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert (q_nope.ndim == 3
|
||||
), f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||
assert (
|
||||
@@ -193,9 +194,13 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype)
|
||||
else q_nope.dtype)
|
||||
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
|
||||
lse = (torch.empty(
|
||||
(B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
|
||||
if self.need_to_return_lse_for_decode else torch.Tensor())
|
||||
|
||||
ops.sm100_cutlass_mla_decode(
|
||||
out,
|
||||
lse,
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
@@ -205,7 +210,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
sm_scale,
|
||||
num_kv_splits,
|
||||
)
|
||||
return out[:, :H].contiguous()
|
||||
returned_lse = lse[:, :H].contiguous(
|
||||
) if self.need_to_return_lse_for_decode else lse
|
||||
return out[:, :H].contiguous(), returned_lse
|
||||
|
||||
def _sm100_forward_decode(
|
||||
self,
|
||||
@@ -213,7 +220,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
@@ -226,13 +233,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
q_nope = q_nope.clone()
|
||||
q_pe = q_pe.clone()
|
||||
|
||||
o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table,
|
||||
self._workspace.get_buf(),
|
||||
self.scale, self._num_kv_splits)
|
||||
o, lse = self._sm100_cutlass_mla_decode(
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table,
|
||||
self._workspace.get_buf(),
|
||||
self.scale,
|
||||
self._num_kv_splits,
|
||||
)
|
||||
|
||||
return o
|
||||
return o, (lse if self.need_to_return_lse_for_decode else None)
|
||||
|
||||
# TODO: Currently we leave it here only for backup in case something is
|
||||
# wrong with the new SM100 CUTLASS MLA kernel
|
||||
@@ -286,4 +298,4 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
attn_metadata), None
|
||||
|
||||
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata), None
|
||||
attn_metadata)
|
||||
|
||||
Reference in New Issue
Block a user