[Attention] MLA support for V1 (#13789)
Signed-off-by: Yang Chen <yangche@fb.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -14,6 +14,11 @@ from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler_output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
@@ -40,6 +45,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return FlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
|
||||
return FlashAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
@@ -85,6 +94,62 @@ class FlashAttentionMetadata:
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
|
||||
class FlashAttentionMetadataBuilder:
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner"):
|
||||
self.runner = runner
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput"):
|
||||
pass
|
||||
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int):
|
||||
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
|
||||
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
|
||||
self.runner.device, non_blocking=True)
|
||||
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
|
||||
non_blocking=True)
|
||||
block_table = (
|
||||
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True).long()
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
if use_cascade:
|
||||
# TODO: Optimize.
|
||||
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
prefix_kv_lens = torch.tensor([common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
|
||||
common_prefix_len)
|
||||
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
|
||||
self.runner.device)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
@@ -371,4 +436,4 @@ def cascade_attention(
|
||||
|
||||
# Merge prefix and suffix outputs, and store the result in output.
|
||||
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
|
||||
suffix_lse)
|
||||
suffix_lse)
|
||||
|
||||
0
vllm/v1/attention/backends/mla/__init__.py
Normal file
0
vllm/v1/attention/backends/mla/__init__.py
Normal file
1022
vllm/v1/attention/backends/mla/common.py
Normal file
1022
vllm/v1/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
Load Diff
110
vllm/v1/attention/backends/triton_mla.py
Normal file
110
vllm/v1/attention/backends/triton_mla.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["TritonMLAImpl"]:
|
||||
return TritonMLAImpl
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
**mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TritonMLAImpl")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
num_kv_splits = 4 # TODO: heuristic
|
||||
|
||||
# TODO(lucas) Allocate ahead of time
|
||||
attn_logits = torch.empty(
|
||||
(
|
||||
B,
|
||||
self.num_heads,
|
||||
num_kv_splits,
|
||||
# NOTE(lucas) idk why the +1 is here but sglang has it so we
|
||||
# just mirror that
|
||||
self.kv_lora_rank + 1,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
# Add a head dim of 1
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
|
||||
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
|
||||
|
||||
# Run MQA
|
||||
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
|
||||
attn_metadata.block_table, attn_metadata.seq_lens,
|
||||
attn_logits, num_kv_splits, self.scale, PAGE_SIZE)
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
||||
Reference in New Issue
Block a user