[Attention] Blackwell FP8 MLA support with CUTLASS_MLA backend (#23289)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>

Authored by Matthew Bonanni on 2025-09-03 14:05:24 -04:00, committed by GitHub
parent 731a6940e3
commit a742322092
4 changed files with 184 additions and 105 deletions

@@ -108,10 +108,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
                 "are not implemented for "
                 "CutlassMLAImpl")
 
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "CutlassMLA V1 with FP8 KV cache not yet supported")
-
         self._use_old_cutlass_mla = False
         force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None)
         if force_old_cutlass:
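Aside: the FORCE_OLD_CUTLASS_MLA escape hatch retained above is a plain environment-variable override. A minimal standalone sketch of the pattern (class context omitted; variable names follow the diff):

import os

# Any non-empty value for FORCE_OLD_CUTLASS_MLA selects the legacy kernel path.
use_old_cutlass_mla = False
force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None)
if force_old_cutlass:
    use_old_cutlass_mla = True

Note that os.environ.get returns the raw string, so even FORCE_OLD_CUTLASS_MLA=0 is truthy and would enable the old path under this check.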
@@ -182,11 +178,10 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
                     > 0), f"block num must be greater than 0, got {block_num}"
         assert block_num % (128 / PAGE_SIZE) == 0
 
-        # TODO(kaixih@nvidia): support fp8
         assert q_nope.dtype in (
-            torch.float16,
-            torch.bfloat16,
-        ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
+            torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
+                f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got "
+                f"{q_nope.dtype}.")
         assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
         assert (
             seq_lens.dtype == torch.int32
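For illustration, the widened assertion now admits FP8 E4M3 queries alongside fp16/bf16, and the stale TODO is dropped. A standalone sketch of the gate in isolation (not vLLM code; torch.float8_e4m3fn requires a reasonably recent PyTorch):

import torch

ALLOWED_DTYPES = (torch.float16, torch.bfloat16, torch.float8_e4m3fn)

for dt in ALLOWED_DTYPES:
    # Factory functions can allocate fp8 directly; randn cannot, so use zeros.
    q_nope = torch.zeros(2, 128, 512, dtype=dt)
    assert q_nope.dtype in ALLOWED_DTYPES, (
        f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got "
        f"{q_nope.dtype}.")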
@@ -195,7 +190,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             page_table.dtype == torch.int32
         ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
 
-        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
+        dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype)
+                 else q_nope.dtype)
+        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
 
         ops.sm100_cutlass_mla_decode(
             out,
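The key behavioral change here: when the KV cache is quantized to FP8, the decode output buffer is allocated in bf16 rather than inheriting the (now possibly fp8) query dtype. A minimal sketch of that allocation logic, with is_quantized_kv_cache stubbed as a string-prefix check (an assumption; vLLM's real helper lives elsewhere):

import torch

def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    # Simplified stand-in for vLLM's helper: treat fp8 variants as quantized.
    return kv_cache_dtype.startswith("fp8")

B_q, MAX_HEADS, D_latent = 4, 128, 512
kv_cache_dtype = "fp8_e4m3"  # hypothetical config value
q_nope = torch.zeros(B_q, MAX_HEADS, D_latent, dtype=torch.float8_e4m3fn)

# Output is materialized in bf16 when the cache is quantized, matching the diff.
dtype = (torch.bfloat16 if is_quantized_kv_cache(kv_cache_dtype)
         else q_nope.dtype)
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
assert out.dtype is torch.bfloat16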
@@ -220,9 +217,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Cutlass MLA not yet supported")
-
         # Adjust workspace size (if necessary)
         self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
@@ -252,8 +246,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Cutlass MLA not yet supported")
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")
 
         B = q_nope.shape[0]
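Lastly, the legacy path keeps its guard: under FORCE_OLD_CUTLASS_MLA, quantized caches are still rejected, now via is_quantized_kv_cache instead of a raw string-prefix check. A small sketch of the resulting behavior (helper stubbed as above, purely illustrative):

def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype.startswith("fp8")  # simplified stand-in

def check_old_path(kv_cache_dtype: str) -> None:
    # Mirrors the guard above: FP8 caches only work on the new SM100 path.
    if is_quantized_kv_cache(kv_cache_dtype):
        raise NotImplementedError(
            "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")

check_old_path("auto")        # unquantized cache: passes
# check_old_path("fp8_e4m3")  # would raise NotImplementedError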