Remove old cutlass mla (#23961)

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni
Date: 2025-09-17 10:31:43 -04:00 (committed by GitHub)
Parent: 47f670b03b
Commit: 8f3616f422

6 changed files with 10 additions and 345 deletions


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import os
-from typing import ClassVar, Optional
+from typing import ClassVar, Optional, Union
 import torch
@@ -109,12 +109,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
                 "are not implemented for "
                 "CutlassMLAImpl")
 
-        self._use_old_cutlass_mla = False
-        force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None)
-        if force_old_cutlass:
-            logger.warning_once("Forcing old cutlass mla kernel")
-            self._use_old_cutlass_mla = True
-
         # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
         # issues. In case the code hangs, use:
         # FORCE_NUM_KV_SPLITS=1
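The retained TODO points at the FORCE_NUM_KV_SPLITS escape hatch for hangs. A minimal sketch of how an override like this is typically read from the environment; the variable names and the fallback value are assumptions for illustration, not the vLLM implementation:

```python
import os

# Hypothetical sketch: per the TODO above, num_kv_splits is capped at 16 and
# FORCE_NUM_KV_SPLITS=1 is the suggested workaround if the kernel hangs.
KV_SPLITS_CAP = 16  # cap taken from the TODO comment

forced = os.environ.get("FORCE_NUM_KV_SPLITS")
if forced is not None:
    num_kv_splits = min(int(forced), KV_SPLITS_CAP)  # e.g. FORCE_NUM_KV_SPLITS=1
else:
    num_kv_splits = KV_SPLITS_CAP  # assumed fallback, for illustration only
```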
@@ -219,16 +213,22 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         return out, returned_lse
 
-    def _sm100_forward_decode(
+    def _forward_decode(
         self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
+        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
         layer: AttentionLayer,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
+        if type(q) is tuple:
+            q_nope, q_pe = q
+        else:
+            q_nope, q_pe = torch.split(
+                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
         # Adjust workspace size (if necessary)
         self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
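The renamed _forward_decode now accepts either a pre-split (q_nope, q_pe) tuple or a single concatenated query tensor. A minimal sketch of the tensor branch, with illustrative shapes (the head dimensions below are assumptions, not values read from any model config):

```python
import torch

# Illustrative shapes only.
B, num_heads = 2, 8
kv_lora_rank, qk_rope_head_dim = 512, 64

# Concatenated query: the last dim holds the "nope" part followed by the rope part.
q = torch.randn(B, num_heads, kv_lora_rank + qk_rope_head_dim)

# torch.split with a list of sizes mirrors the else-branch above.
q_nope, q_pe = torch.split(q, [kv_lora_rank, qk_rope_head_dim], dim=-1)
assert q_nope.shape == (B, num_heads, kv_lora_rank)
assert q_pe.shape == (B, num_heads, qk_rope_head_dim)
```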
@@ -245,57 +245,3 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         )
 
         return o, (lse if self.need_to_return_lse_for_decode else None)
-
-    # TODO: Currently we leave it here only for backup in case something is
-    # wrong with the new SM100 CUTLASS MLA kernel
-    def _old_forward_decode(
-        self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
-        attn_metadata: MLACommonMetadata,
-    ) -> torch.Tensor:
-        assert kv_c_and_k_pe_cache.numel() > 0
-        assert attn_metadata.decode is not None
-
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")
-
-        B = q_nope.shape[0]
-
-        o = torch.empty((B, self.num_heads, self.kv_lora_rank),
-                        dtype=q_nope.dtype,
-                        device=q_nope.device)
-
-        # Run MLA
-        # Clone q_nope and q_pe to make sure strides computation is correct.
-        q_nope = q_nope.clone()
-        q_pe = q_pe.clone()
-
-        ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
-                               attn_metadata.decode.seq_lens,
-                               attn_metadata.decode.block_table, self.scale)
-
-        return o
-
-    def _forward_decode(
-        self,
-        q: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
-        attn_metadata: MLACommonMetadata,
-        layer: AttentionLayer,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if type(q) is tuple:
-            q_nope, q_pe = q
-        else:
-            q_nope, q_pe = torch.split(
-                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-
-        if self._use_old_cutlass_mla:
-            # TODO: Remove the old cutlass MLA kernel after more extensive
-            # testing
-            return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
-                                            attn_metadata), None
-
-        return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
-                                          attn_metadata)
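Since the removed dispatcher used the same tuple-or-tensor handling that now lives directly in _forward_decode, the two calling conventions remain interchangeable. A usage-style sketch (shapes are hypothetical) showing that concatenating and re-splitting recovers the original pair:

```python
import torch

B, H, LORA, ROPE = 2, 8, 512, 64  # hypothetical sizes
q_nope = torch.randn(B, H, LORA)
q_pe = torch.randn(B, H, ROPE)

# Tuple form and concatenated-tensor form carry the same information:
q_cat = torch.cat([q_nope, q_pe], dim=-1)
split_nope, split_pe = torch.split(q_cat, [LORA, ROPE], dim=-1)

assert torch.equal(split_nope, q_nope)
assert torch.equal(split_pe, q_pe)
```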