Remove old cutlass mla (#23961)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import os
-from typing import ClassVar, Optional
+from typing import ClassVar, Optional, Union
 
 import torch
 
@@ -109,12 +109,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
                 "are not implemented for "
                 "CutlassMLAImpl")
 
-        self._use_old_cutlass_mla = False
-        force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None)
-        if force_old_cutlass:
-            logger.warning_once("Forcing old cutlass mla kernel")
-            self._use_old_cutlass_mla = True
-
         # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
         # issues. In case the code hangs, use:
         # FORCE_NUM_KV_SPLITS=1
@@ -219,16 +213,22 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
 
         return out, returned_lse
 
-    def _sm100_forward_decode(
+    def _forward_decode(
         self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
+        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None
 
+        if type(q) is tuple:
+            q_nope, q_pe = q
+        else:
+            q_nope, q_pe = torch.split(
+                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
         # Adjust workspace size (if necessary)
         self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
 
@@ -245,57 +245,3 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         )
 
         return o, (lse if self.need_to_return_lse_for_decode else None)
-
-    # TODO: Currently we leave it here only for backup in case something is
-    # wrong with the new SM100 CUTLASS MLA kernel
-    def _old_forward_decode(
-        self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
-        attn_metadata: MLACommonMetadata,
-    ) -> torch.Tensor:
-        assert kv_c_and_k_pe_cache.numel() > 0
-        assert attn_metadata.decode is not None
-
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")
-
-        B = q_nope.shape[0]
-
-        o = torch.empty((B, self.num_heads, self.kv_lora_rank),
-                        dtype=q_nope.dtype,
-                        device=q_nope.device)
-
-        # Run MLA
-        # Clone q_nope and q_pe to make sure strides computation is correct.
-        q_nope = q_nope.clone()
-        q_pe = q_pe.clone()
-
-        ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
-                               attn_metadata.decode.seq_lens,
-                               attn_metadata.decode.block_table, self.scale)
-
-        return o
-
-    def _forward_decode(
-        self,
-        q: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
-        attn_metadata: MLACommonMetadata,
-        layer: AttentionLayer,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if type(q) is tuple:
-            q_nope, q_pe = q
-        else:
-            q_nope, q_pe = torch.split(
-                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        if self._use_old_cutlass_mla:
-            # TODO: Remove the old cutlass MLA kernel after more extensive
-            # testing
-            return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
-                                            attn_metadata), None
-
-        return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
-                                          attn_metadata)
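
For reference, the consolidated `_forward_decode` now takes a single `q` argument that may be either a fused query tensor or a pre-split `(q_nope, q_pe)` tuple. Below is a minimal, self-contained sketch of that normalization step, using made-up dimensions and a standalone helper name (`split_query`); it is an illustration of the pattern, not the vLLM class itself.

```python
from typing import Union

import torch

# Hypothetical dimensions, for illustration only.
KV_LORA_RANK = 512
QK_ROPE_HEAD_DIM = 64


def split_query(
    q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Normalize q into (q_nope, q_pe), mirroring the unified signature."""
    if isinstance(q, tuple):
        # Query was already split by an upstream code path.
        return q
    # Fused layout: the last dim holds kv_lora_rank + qk_rope_head_dim.
    return torch.split(q, [KV_LORA_RANK, QK_ROPE_HEAD_DIM], dim=-1)


# Usage: a fused query of shape (tokens, heads, 512 + 64) splits cleanly.
q = torch.randn(4, 16, KV_LORA_RANK + QK_ROPE_HEAD_DIM)
q_nope, q_pe = split_query(q)
assert q_nope.shape[-1] == KV_LORA_RANK
assert q_pe.shape[-1] == QK_ROPE_HEAD_DIM
```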