[Attention] Blackwell FP8 MLA support with CUTLASS_MLA backend (#23289)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -108,10 +108,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
                 "are not implemented for "
                 "CutlassMLAImpl")
 
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "CutlassMLA V1 with FP8 KV cache not yet supported")
-
         self._use_old_cutlass_mla = False
         force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None)
         if force_old_cutlass:
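The hunk above removes the constructor-time gate that rejected quantized KV caches outright, which is what lets the FP8 path reach the kernel at all. For reference, the helper it calls behaves roughly as sketched below; this is a hedged approximation, not vLLM's actual definition, and the real helper lives elsewhere in the tree:

def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    # "auto" means the KV cache inherits the model's dtype (fp16/bf16);
    # any explicit value ("fp8", "fp8_e4m3", ...) denotes a quantized cache.
    return kv_cache_dtype != "auto"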
@@ -182,11 +178,10 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             > 0), f"block num must be greater than 0, got {block_num}"
         assert block_num % (128 / PAGE_SIZE) == 0
 
-        # TODO(kaixih@nvidia): support fp8
         assert q_nope.dtype in (
-            torch.float16,
-            torch.bfloat16,
-        ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
+            torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
+                f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got "
+                f"{q_nope.dtype}.")
         assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
         assert (
             seq_lens.dtype == torch.int32
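With the TODO resolved, the dtype check now admits float8_e4m3fn queries alongside fp16/bf16. A minimal standalone illustration (assumes a PyTorch build with float8 support; the tensor shape is arbitrary, not the kernel's real layout):

import torch

ALLOWED = (torch.float16, torch.bfloat16, torch.float8_e4m3fn)

# An e4m3 query tensor now passes the same assertion that previously
# accepted only fp16 and bf16.
q_nope = torch.randn(4, 128, 512, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
assert q_nope.dtype in ALLOWED, (
    f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}.")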
@@ -195,7 +190,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             page_table.dtype == torch.int32
         ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
 
-        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
+        dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype)
+                 else q_nope.dtype)
+        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
 
         ops.sm100_cutlass_mla_decode(
             out,
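The allocation change above is the key numerical decision in this commit: even when the queries and KV cache are FP8, the output buffer is allocated in bf16, so the decode kernel hands downstream projections a higher-precision result rather than an e4m3 tensor. A sketch of the rule, with the quantization test inlined as an assumption:

import torch

def output_dtype(kv_cache_dtype: str, q_dtype: torch.dtype) -> torch.dtype:
    # Assumption: mirrors is_quantized_kv_cache ("auto" == unquantized).
    quantized = kv_cache_dtype != "auto"
    return torch.bfloat16 if quantized else q_dtype

assert output_dtype("fp8", torch.float8_e4m3fn) is torch.bfloat16
assert output_dtype("auto", torch.float16) is torch.float16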
@@ -220,9 +217,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Cutlass MLA not yet supported")
-
         # Adjust workspace size (if necessary)
         self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
 
@@ -252,8 +246,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Cutlass MLA not yet supported")
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")
 
         B = q_nope.shape[0]
 
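The legacy kernel path keeps a guard, rewritten to use is_quantized_kv_cache and to name the actual conflict: the FORCE_OLD_CUTLASS_MLA environment variable selects the old kernel, which has no FP8 support. A condensed sketch of how the two knobs interact (hypothetical helper; in the real code the env-var check happens in __init__ and the dtype check in the old forward path):

import os

def check_old_cutlass_path(kv_cache_dtype: str) -> None:
    # Forcing the pre-Blackwell kernel while requesting a quantized cache
    # is the one combination this commit still rejects.
    if os.environ.get("FORCE_OLD_CUTLASS_MLA") and kv_cache_dtype != "auto":
        raise NotImplementedError(
            "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")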