[BugFix][V1][ROCm] Triton MLA uses V0 backend on V1 engine (#19067)
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
This commit is contained in:
@@ -640,7 +640,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||
@@ -672,11 +671,17 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
maybe_padded_v = torch.nn.functional.pad(
|
||||
v, [0, q.shape[-1] - v.shape[-1]], value=0)
|
||||
|
||||
if is_vllm_fa:
|
||||
kwargs["return_softmax_lse"] = return_softmax_lse
|
||||
else:
|
||||
# ROCm leverages the upstream flash_attn, which takes a parameter
|
||||
# called "return_attn_probs" instead of return_softmax_lse
|
||||
kwargs["return_attn_probs"] = return_softmax_lse
|
||||
|
||||
attn_out = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=maybe_padded_v,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -5,10 +5,14 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
@@ -68,6 +72,59 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
raise NotImplementedError(
|
||||
"TritonMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
|
||||
self.triton_fa_func = triton_attention if HAS_TRITON else None
|
||||
|
||||
def _flash_attn_varlen_diff_headdims_rocm(self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
softmax_scale=None,
|
||||
**kwargs):
|
||||
assert self.triton_fa_func is not None
|
||||
|
||||
# Triton Attention requires a padded V
|
||||
padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
value=0)
|
||||
# The output of triton_attention is a tuple of
|
||||
# [output_tensor, encoded_softmax] where encoded_softmax is always None
|
||||
output_tensor, _ = self.triton_fa_func(
|
||||
q,
|
||||
k,
|
||||
padded_v,
|
||||
None, # output
|
||||
kwargs["cu_seqlens_q"],
|
||||
kwargs["cu_seqlens_k"],
|
||||
kwargs["max_seqlen_q"],
|
||||
kwargs["max_seqlen_k"],
|
||||
kwargs["causal"],
|
||||
softmax_scale,
|
||||
None, # bias
|
||||
)
|
||||
|
||||
return output_tensor
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=False,
|
||||
softmax_scale=None,
|
||||
**kwargs):
|
||||
if current_platform.is_rocm() \
|
||||
and self.use_triton_flash_attn \
|
||||
and not return_softmax_lse:
|
||||
return self._flash_attn_varlen_diff_headdims_rocm(
|
||||
q, k, v, softmax_scale=softmax_scale, **kwargs)
|
||||
else:
|
||||
return super()._flash_attn_varlen_diff_headdims(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user