[5/N][Attention] Finish eliminating vllm/attention folder (#32064)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
vllm/model_executor/layers/attention/mla_attention.py (371 changed lines; Executable file → Normal file)
@@ -191,24 +191,38 @@ import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import ClassVar, Generic, TypeVar
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, cast

if TYPE_CHECKING:
    from flashinfer import BatchPrefillWithRaggedKVCacheWrapper

import torch
import torch.nn as nn
from tqdm import tqdm

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
from vllm.model_executor.layers.attention.attention import (
    _init_kv_cache_quant,
    get_attention_context,
    set_default_quant_scales,
    should_load_quant_weights,
)
from vllm.model_executor.layers.attention.kv_transfer_utils import (
    maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
@@ -217,11 +231,16 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down
from vllm.utils.torch_utils import (
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionLayer,
    AttentionMetadata,
    AttentionMetadataBuilder,
    AttentionType,
    CommonAttentionMetadata,
    MLAAttentionImpl,
)
@@ -234,7 +253,320 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
)

logger = init_logger(__name__)


class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes query, and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.quant_config = quant_config

        # Initialize KV cache quantization attributes
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        _init_kv_cache_quant(self, quant_config, prefix)

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )

        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "TRITON_MLA"
                or self.attn_backend.get_name() == "FLASHINFER"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        self.use_direct_call = not current_platform.opaque_attention_op()

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        self.kv_cache = [
            torch.tensor([])
            for _ in range(
                get_current_vllm_config().parallel_config.pipeline_parallel_size
            )
        ]

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

            if self.attn_backend.accept_output_buffer:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                self.impl.forward(
                    self,
                    q,
                    kv_c_normed,
                    k_pe,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
                return output
            else:
                return self.impl.forward(
                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                )
        else:
            if self.attn_backend.accept_output_buffer:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                torch.ops.vllm.unified_mla_attention_with_output(
                    q,
                    kv_c_normed,
                    k_pe,
                    output,
                    self.layer_name,
                )
                return output
            else:
                return torch.ops.vllm.unified_mla_attention(
                    q,
                    kv_c_normed,
                    k_pe,
                    self.layer_name,
                )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )


@maybe_transfer_kv_layer
def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)

    return output


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    return torch.empty_like(q).contiguous()


direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    mutates_args=[],
    fake_impl=unified_mla_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)


@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    self.impl.forward(
        self,
        q,
        kv_c_normed,
        k_pe,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    return


direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_mla_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)


class QueryLenSupport(Enum):
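The hunk above registers unified_mla_attention and unified_mla_attention_with_output as custom ops, each paired with a fake implementation that only reports the output shape and dtype so graph tracing can proceed without running the kernel. Below is a minimal standalone sketch of that register-real-plus-fake pattern written against PyTorch's public torch.library.custom_op API (PyTorch 2.4+); it is not the commit's code: the demo namespace, op name, and stand-in computation are invented for illustration, and vLLM's direct_register_custom_op helper performs a comparable registration through torch.library.

import torch


@torch.library.custom_op("demo::unified_attention", mutates_args=())
def unified_attention(q: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Stand-in for the real computation, which would look up the layer's
    # metadata and KV cache by layer_name.
    return q * 2.0


@unified_attention.register_fake
def _(q: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Mirrors unified_mla_attention_fake above: describe only the output's
    # shape and dtype so the op can be traced without executing.
    return torch.empty_like(q)


out = torch.ops.demo.unified_attention(torch.randn(2, 4), "layers.0.attn")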
@@ -266,15 +598,12 @@ except ImportError:
    from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]
    is_vllm_fa = False

try:
    from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache  # noqa: F401

    flashinfer_available = True
except ImportError:
    BatchPrefillWithRaggedKVCacheWrapper = object
@functools.cache
def flashinfer_available() -> bool:
    import importlib.util

    flashinfer_available = False
    return importlib.util.find_spec("flashinfer") is not None


def dynamic_per_batched_tensor_quant(
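The hunk above also swaps the module-level flashinfer try/except import flag for a cached runtime check. A minimal standalone sketch of that lazy availability-check pattern follows; it uses only the standard library, so it runs whether or not flashinfer is installed, and the __main__ probe is illustrative rather than part of the commit.

import functools
import importlib.util


@functools.cache
def flashinfer_available() -> bool:
    # find_spec returns None when the package cannot be located, so this
    # probes for flashinfer without importing it (and without risking an
    # ImportError at module import time); functools.cache memoizes the result.
    return importlib.util.find_spec("flashinfer") is not None


if __name__ == "__main__":
    print("flashinfer available:", flashinfer_available())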
@@ -398,8 +727,8 @@ class MLACommonPrefillMetadata:

@dataclass
class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
    prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
    prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = field(
    prefill_main: "BatchPrefillWithRaggedKVCacheWrapper | None" = None
    prefill_chunks: "list[BatchPrefillWithRaggedKVCacheWrapper]" = field(
        default_factory=list
    )

@@ -495,7 +824,7 @@ def use_flashinfer_prefill() -> bool:
    vllm_config = get_current_vllm_config()
    if not (
        not vllm_config.attention_config.disable_flashinfer_prefill
        and flashinfer_available
        and flashinfer_available()
        and not vllm_config.attention_config.use_cudnn_prefill
        and current_platform.is_device_capability_family(100)
    ):
@@ -509,7 +838,7 @@ def use_cudnn_prefill() -> bool:

    vllm_config = get_current_vllm_config()
    return (
        flashinfer_available
        flashinfer_available()
        and vllm_config.attention_config.use_cudnn_prefill
        and current_platform.is_device_capability_family(100)
        and has_nvidia_artifactory()
@@ -731,6 +1060,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
            has_context = True

        if self._fi_prefill_main is None:
            from flashinfer import BatchPrefillWithRaggedKVCacheWrapper

            self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper(
                self._workspace_buffer, "NHD", backend="cutlass"
            )
@@ -739,6 +1070,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
            num_chunks = chunked_context.cu_seq_lens.shape[0]
            # Allocate more prefill chunk wrappers if needed
            if len(self._fi_prefill_chunks) < num_chunks:
                from flashinfer import BatchPrefillWithRaggedKVCacheWrapper

                for _ in range(len(self._fi_prefill_chunks), num_chunks):
                    self._fi_prefill_chunks.append(
                        BatchPrefillWithRaggedKVCacheWrapper(
@@ -1513,6 +1846,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
        ):
            assert isinstance(prefill, CudnnPrefillMetadata)
            assert prefill.query_seq_lens is not None
            from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache

            output, lse = cudnn_batch_prefill_with_kv_cache(
                q=q,
                k_cache=k,
@@ -1572,6 +1907,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
        assert prefill.chunked_context is not None
        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
        assert prefill.query_seq_lens is not None
        from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache

        return cudnn_batch_prefill_with_kv_cache(
            q=q,
            k_cache=k,