[1/N][Attention] Restructure attention: move files (#31916)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
449
vllm/v1/attention/backend.py
Normal file
449
vllm/v1/attention/backend.py
Normal file
@@ -0,0 +1,449 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
|
||||
|
||||
class AttentionType(str, Enum):
|
||||
"""
|
||||
Attention type.
|
||||
Use string to be compatible with `torch.compile`.
|
||||
"""
|
||||
|
||||
DECODER = "decoder"
|
||||
"""Decoder attention between previous layer Q/K/V."""
|
||||
ENCODER = "encoder"
|
||||
"""Encoder attention between previous layer Q/K/V for encoder-decoder."""
|
||||
ENCODER_ONLY = "encoder_only"
|
||||
"""Encoder attention between previous layer Q/K/V."""
|
||||
ENCODER_DECODER = "encoder_decoder"
|
||||
"""Attention between dec. Q and enc. K/V for encoder-decoder."""
|
||||
|
||||
|
||||
class MultipleOf:
|
||||
base: int
|
||||
|
||||
def __init__(self, base: int):
|
||||
self.base = base
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
|
||||
"""Abstract class for attention backends."""
|
||||
|
||||
# For some attention backends, we allocate an output tensor before
|
||||
# calling the custom op. When piecewise cudagraph is enabled, this
|
||||
# makes sure the output tensor is allocated inside the cudagraph.
|
||||
accept_output_buffer: bool = False
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(1)]
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_name() -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_impl_cls() -> type["AttentionImpl"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order(
|
||||
include_num_layers_dimension: bool = False,
|
||||
) -> tuple[int, ...]:
|
||||
"""
|
||||
Get the physical (memory layout) ordering of the kv cache dimensions.
|
||||
e.g. if the KV cache shape is
|
||||
[2, num_blocks, block_size, num_heads, head_size],
|
||||
and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
|
||||
ordering of dimensions is
|
||||
[num_blocks, num_heads, 2, block_size, head_size].
|
||||
|
||||
If this function is unimplemented / raises NotImplementedError,
|
||||
the physical layout of the KV cache will match the logical shape.
|
||||
|
||||
Args:
|
||||
include_num_layers_dimension: if True, includes an additional
|
||||
num_layers dimension, which is assumed to be prepended
|
||||
to the logical KV cache shape.
|
||||
With the above example, a return value (2, 4, 0, 1, 3, 5)
|
||||
corresponds to
|
||||
[num_blocks, num_heads, num_layers, 2, block_size, head_size].
|
||||
|
||||
If an additional dimension is NOT included in the returned
|
||||
tuple, the physical layout will not include a layers dimension.
|
||||
|
||||
Returns:
|
||||
A tuple of ints which is a permutation of range(len(shape)).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def full_cls_name(cls) -> tuple[str, str]:
|
||||
return (cls.__module__, cls.__qualname__)
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
return (not supported_head_sizes) or head_size in supported_head_sizes
|
||||
|
||||
@classmethod
|
||||
def supports_dtype(cls, dtype: torch.dtype) -> bool:
|
||||
return dtype in cls.supported_dtypes
|
||||
|
||||
@classmethod
|
||||
def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
|
||||
if kv_cache_dtype is None:
|
||||
return True
|
||||
return (not cls.supported_kv_cache_dtypes) or (
|
||||
kv_cache_dtype in cls.supported_kv_cache_dtypes
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def supports_block_size(cls, block_size: int | None) -> bool:
|
||||
from vllm.config.cache import BlockSize
|
||||
|
||||
if block_size is None:
|
||||
return True
|
||||
|
||||
valid_sizes = get_args(BlockSize)
|
||||
if block_size not in valid_sizes:
|
||||
return False
|
||||
|
||||
supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
|
||||
if not supported_kernel_block_sizes:
|
||||
return True
|
||||
|
||||
for supported_size in supported_kernel_block_sizes:
|
||||
if isinstance(supported_size, MultipleOf):
|
||||
supported_size = supported_size.base
|
||||
# With hybrid_blocks feature, the framework-level block size
|
||||
# only needs to be a multiple of the kernel's requirement,
|
||||
# even if the kernel requires a fixed block_size.
|
||||
if block_size % supported_size == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_mm_prefix(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""Check if backend supports a given attention type.
|
||||
|
||||
By default, only supports decoder attention.
|
||||
Backends should override this to support other attention types.
|
||||
"""
|
||||
return attn_type == AttentionType.DECODER
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: "CacheDType | None",
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: "DeviceCapability",
|
||||
) -> str | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def validate_configuration(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: "CacheDType | None",
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
use_mm_prefix: bool,
|
||||
device_capability: "DeviceCapability",
|
||||
attn_type: str,
|
||||
) -> list[str]:
|
||||
invalid_reasons = []
|
||||
if not cls.supports_head_size(head_size):
|
||||
invalid_reasons.append("head_size not supported")
|
||||
if not cls.supports_dtype(dtype):
|
||||
invalid_reasons.append("dtype not supported")
|
||||
if not cls.supports_kv_cache_dtype(kv_cache_dtype):
|
||||
invalid_reasons.append("kv_cache_dtype not supported")
|
||||
if not cls.supports_block_size(block_size):
|
||||
invalid_reasons.append("block_size not supported")
|
||||
if use_mm_prefix and not cls.supports_mm_prefix():
|
||||
invalid_reasons.append(
|
||||
"partial multimodal token full attention not supported"
|
||||
)
|
||||
if use_mla != cls.is_mla():
|
||||
if use_mla:
|
||||
invalid_reasons.append("MLA not supported")
|
||||
else:
|
||||
invalid_reasons.append("non-MLA not supported")
|
||||
if has_sink and not cls.supports_sink():
|
||||
invalid_reasons.append("sink setting not supported")
|
||||
if use_sparse != cls.is_sparse():
|
||||
if use_sparse:
|
||||
invalid_reasons.append("sparse not supported")
|
||||
else:
|
||||
invalid_reasons.append("non-sparse not supported")
|
||||
if not cls.supports_compute_capability(device_capability):
|
||||
invalid_reasons.append("compute capability not supported")
|
||||
if not cls.supports_attn_type(attn_type):
|
||||
invalid_reasons.append(f"attention type {attn_type} not supported")
|
||||
combination_reason = cls.supports_combination(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
)
|
||||
if combination_reason is not None:
|
||||
invalid_reasons.append(combination_reason)
|
||||
return invalid_reasons
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return None
|
||||
|
||||
|
||||
class AttentionMetadata:
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T", bound=AttentionMetadata)
|
||||
|
||||
|
||||
class AttentionLayer(Protocol):
|
||||
_q_scale: torch.Tensor
|
||||
_k_scale: torch.Tensor
|
||||
_v_scale: torch.Tensor
|
||||
_q_scale_float: float
|
||||
_k_scale_float: float
|
||||
_v_scale_float: float
|
||||
_prob_scale: torch.Tensor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class AttentionImpl(ABC, Generic[T]):
|
||||
# Required attributes that all impls should have
|
||||
num_heads: int
|
||||
head_size: int
|
||||
scale: float
|
||||
|
||||
# Whether the attention impl can return the softmax lse for decode.
|
||||
# Some features like decode context parallelism require the softmax lse.
|
||||
can_return_lse_for_decode: bool = False
|
||||
|
||||
# Whether the attention impl supports Prefill Context Parallelism.
|
||||
supports_pcp: bool = False
|
||||
# Whether the attention impl(or ops) supports MTP
|
||||
# when cp_kv_cache_interleave_size > 1
|
||||
supports_mtp_with_cp_non_trivial_interleave_size: bool = False
|
||||
|
||||
# some attention backends might not always want to return lse
|
||||
# even if they can return lse (for efficiency reasons)
|
||||
need_to_return_lse_for_decode: bool = False
|
||||
|
||||
# Whether this attention implementation supports pre-quantized query input.
|
||||
# When True, the attention layer will quantize queries before passing them
|
||||
# to this backend, allowing torch.compile to fuse the quantization with
|
||||
# previous operations. This is typically supported when using FP8 KV cache
|
||||
# with compatible attention kernels (e.g., TRT-LLM).
|
||||
# Subclasses should set this in __init__.
|
||||
# TODO add support to more backends:
|
||||
# https://github.com/vllm-project/vllm/issues/25584
|
||||
supports_quant_query_input: bool = False
|
||||
|
||||
dcp_world_size: int
|
||||
dcp_rank: int
|
||||
|
||||
pcp_world_size: int
|
||||
pcp_rank: int
|
||||
|
||||
total_cp_world_size: int
|
||||
total_cp_rank: int
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# use __new__ so that all subclasses will call this
|
||||
self = super().__new__(cls)
|
||||
try:
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
try:
|
||||
from vllm.distributed.parallel_state import get_pcp_group
|
||||
|
||||
self.pcp_world_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
self.pcp_world_size = 1
|
||||
self.pcp_rank = 0
|
||||
self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
|
||||
self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
|
||||
|
||||
self.need_to_return_lse_for_decode = (
|
||||
self.dcp_world_size > 1 and self.can_return_lse_for_decode
|
||||
)
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int | None = None,
|
||||
alibi_slopes: list[float] | None = None,
|
||||
sliding_window: int | None = None,
|
||||
kv_cache_dtype: str = "auto",
|
||||
logits_soft_cap: float | None = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def fused_output_quant_supported(self, quant_key: "QuantKey"):
|
||||
"""
|
||||
Does this attention implementation support fused output quantization.
|
||||
This is used by the AttnFusionPass to only fuse output quantization
|
||||
onto implementations that support it.
|
||||
|
||||
:param quant_key: QuantKey object that describes the quantization op
|
||||
:return: is fusion supported for this type of quantization
|
||||
"""
|
||||
return False
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
pass
|
||||
|
||||
|
||||
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
kv_b_proj: "ColumnParallelLinear",
|
||||
indexer: object | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_cq: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
||||
return kv_cache_dtype != "auto"
|
||||
@@ -6,16 +6,16 @@ from typing import ClassVar
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
|
||||
130
vllm/v1/attention/backends/fa_utils.py
Normal file
130
vllm/v1/attention/backends/fa_utils.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm import _custom_ops
|
||||
|
||||
ops = _custom_ops
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
|
||||
flash_attn_varlen_func,
|
||||
get_scheduler_metadata,
|
||||
)
|
||||
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
|
||||
ops = ipex_ops
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
flash_attn_varlen_func = ops.flash_attn_varlen_func
|
||||
get_scheduler_metadata = ops.get_scheduler_metadata
|
||||
|
||||
elif current_platform.is_rocm():
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Rocm platform requires upstream flash-attn "
|
||||
"to be installed. Please install flash-attn first."
|
||||
) from e
|
||||
|
||||
|
||||
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
# import here to avoid circular dependencies
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_xpu():
|
||||
return 2
|
||||
if current_platform.is_rocm():
|
||||
# ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
|
||||
return None
|
||||
try:
|
||||
from vllm.vllm_flash_attn.flash_attn_interface import (
|
||||
fa_version_unsupported_reason,
|
||||
is_fa_version_supported,
|
||||
)
|
||||
|
||||
device_capability = current_platform.get_device_capability()
|
||||
|
||||
assert device_capability is not None
|
||||
|
||||
# 1. default version depending on platform
|
||||
fa_version = (
|
||||
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
|
||||
)
|
||||
|
||||
# 2. override if passed by environment or config
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
if (
|
||||
vllm_config is not None
|
||||
and vllm_config.attention_config.flash_attn_version is not None
|
||||
):
|
||||
fa_version = vllm_config.attention_config.flash_attn_version
|
||||
|
||||
# 3. fallback for unsupported combinations
|
||||
if device_capability.major == 10 and fa_version == 3:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 3 on Blackwell platform "
|
||||
"defaulting to FA version 2."
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
if requires_alibi and fa_version == 3:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
if not is_fa_version_supported(fa_version):
|
||||
logger.error(
|
||||
"Cannot use FA version %d is not supported due to %s",
|
||||
fa_version,
|
||||
fa_version_unsupported_reason(fa_version),
|
||||
)
|
||||
|
||||
assert is_fa_version_supported(fa_version)
|
||||
return fa_version
|
||||
except (ImportError, AssertionError):
|
||||
return None
|
||||
|
||||
|
||||
def flash_attn_supports_fp8() -> bool:
|
||||
return (
|
||||
get_flash_attn_version() == 3
|
||||
and current_platform.is_device_capability_family(90)
|
||||
)
|
||||
|
||||
|
||||
def flash_attn_supports_sinks() -> bool:
|
||||
if current_platform.is_xpu():
|
||||
return True
|
||||
else:
|
||||
return get_flash_attn_version() == 3
|
||||
|
||||
|
||||
def flash_attn_supports_mla():
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
from vllm.vllm_flash_attn.flash_attn_interface import (
|
||||
is_fa_version_supported,
|
||||
)
|
||||
|
||||
return is_fa_version_supported(
|
||||
3
|
||||
) and current_platform.is_device_capability_family(90)
|
||||
except (ImportError, AssertionError):
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def is_flash_attn_varlen_func_available() -> bool:
|
||||
return current_platform.is_cuda() or current_platform.is_xpu()
|
||||
@@ -9,24 +9,24 @@ from typing import ClassVar
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
from vllm.v1.attention.backends.fa_utils import (
|
||||
flash_attn_supports_fp8,
|
||||
get_flash_attn_version,
|
||||
is_flash_attn_varlen_func_available,
|
||||
)
|
||||
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||
|
||||
if is_flash_attn_varlen_func_available():
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
from vllm.v1.attention.backends.fa_utils import (
|
||||
flash_attn_supports_sinks,
|
||||
flash_attn_varlen_func,
|
||||
get_scheduler_metadata,
|
||||
|
||||
@@ -4,14 +4,14 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_diffkv,
|
||||
)
|
||||
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
|
||||
|
||||
if is_flash_attn_varlen_func_available():
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
|
||||
|
||||
@@ -19,14 +19,6 @@ from flashinfer.utils import FP4Tensor
|
||||
from typing_extensions import override
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
@@ -48,6 +40,12 @@ from vllm.utils.flashinfer import (
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
@@ -59,6 +57,8 @@ from vllm.v1.attention.backends.utils import (
|
||||
infer_global_hyperparameters,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
|
||||
@@ -20,12 +20,6 @@ from torch.nn.attention.flex_attention import (
|
||||
or_masks,
|
||||
)
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
@@ -35,6 +29,12 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
|
||||
@@ -6,8 +6,8 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
AttentionCGSupport,
|
||||
|
||||
@@ -4,8 +4,8 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
|
||||
@@ -5,9 +5,9 @@ from dataclasses import dataclass, replace
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
|
||||
@@ -199,15 +199,6 @@ from tqdm import tqdm
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl,
|
||||
)
|
||||
from vllm.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
|
||||
from vllm.logger import init_logger
|
||||
@@ -222,6 +213,13 @@ from vllm.model_executor.layers.linear import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_nvidia_artifactory
|
||||
from vllm.utils.math_utils import cdiv, round_down
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
@@ -230,6 +228,8 @@ from vllm.v1.attention.backends.utils import (
|
||||
infer_global_hyperparameters,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
|
||||
@@ -7,15 +7,15 @@ from typing import ClassVar
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
|
||||
@@ -6,16 +6,6 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
flash_attn_supports_mla,
|
||||
get_flash_attn_version,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
@@ -23,6 +13,16 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.attention.backends.fa_utils import (
|
||||
flash_attn_supports_mla,
|
||||
get_flash_attn_version,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
|
||||
@@ -6,14 +6,14 @@ from typing import ClassVar
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
|
||||
@@ -6,12 +6,6 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
@@ -19,6 +13,7 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
@@ -32,6 +27,11 @@ from vllm.v1.attention.backends.utils import (
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
)
|
||||
from vllm.v1.attention.ops.flashmla import (
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -78,11 +78,11 @@ class FlashMLABackend(MLACommonBackend):
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
if use_sparse:
|
||||
from vllm.attention.ops.flashmla import is_flashmla_sparse_supported
|
||||
from vllm.v1.attention.ops.flashmla import is_flashmla_sparse_supported
|
||||
|
||||
return is_flashmla_sparse_supported()[1]
|
||||
else:
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
|
||||
return is_flashmla_dense_supported()[1]
|
||||
|
||||
|
||||
@@ -7,17 +7,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_sparse_prefill,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
@@ -25,6 +14,12 @@ from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
@@ -35,6 +30,11 @@ from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
from vllm.v1.attention.ops.flashmla import (
|
||||
flash_mla_sparse_prefill,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
|
||||
@@ -5,14 +5,14 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
|
||||
@@ -7,8 +7,8 @@ from typing import ClassVar
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import AttentionLayer, MultipleOf
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backend import AttentionLayer, MultipleOf
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
|
||||
@@ -9,13 +9,13 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
triton_convert_req_index_to_global_index,
|
||||
|
||||
@@ -5,23 +5,23 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
258
vllm/v1/attention/backends/registry.py
Normal file
258
vllm/v1/attention/backends/registry.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention backend registry"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from enum import Enum, EnumMeta
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _AttentionBackendEnumMeta(EnumMeta):
|
||||
"""Metaclass for AttentionBackendEnum to provide better error messages."""
|
||||
|
||||
def __getitem__(cls, name: str):
|
||||
"""Get backend by name with helpful error messages."""
|
||||
try:
|
||||
return super().__getitem__(name)
|
||||
except KeyError:
|
||||
members = cast("dict[str, Enum]", cls.__members__).keys()
|
||||
valid_backends = ", ".join(members)
|
||||
raise ValueError(
|
||||
f"Unknown attention backend: '{name}'. "
|
||||
f"Valid options are: {valid_backends}"
|
||||
) from None
|
||||
|
||||
|
||||
class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"""Enumeration of all supported attention backends.
|
||||
|
||||
The enum value is the default class path, but this can be overridden
|
||||
at runtime using register_backend().
|
||||
|
||||
To get the actual backend class (respecting overrides), use:
|
||||
backend.get_class()
|
||||
"""
|
||||
|
||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
FLASH_ATTN_DIFFKV = (
|
||||
"vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend"
|
||||
)
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
|
||||
ROCM_AITER_TRITON_MLA = (
|
||||
"vllm.v1.attention.backends.mla.aiter_triton_mla.AiterTritonMLABackend"
|
||||
)
|
||||
ROCM_AITER_FA = (
|
||||
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||
)
|
||||
ROCM_AITER_MLA_SPARSE = (
|
||||
"vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
|
||||
)
|
||||
TORCH_SDPA = "" # this tag is only used for ViT
|
||||
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
|
||||
FLASHINFER_MLA = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
|
||||
)
|
||||
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
|
||||
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
|
||||
FLASHMLA_SPARSE = (
|
||||
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
|
||||
)
|
||||
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
|
||||
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
|
||||
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
|
||||
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
||||
TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
|
||||
ROCM_AITER_UNIFIED_ATTN = (
|
||||
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
|
||||
"RocmAiterUnifiedAttentionBackend"
|
||||
)
|
||||
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
|
||||
# Placeholder for third-party/custom backends - must be registered before use
|
||||
# set to None to avoid alias with other backend, whose value is an empty string
|
||||
CUSTOM = None
|
||||
|
||||
def get_path(self, include_classname: bool = True) -> str:
|
||||
"""Get the class path for this backend (respects overrides).
|
||||
|
||||
Returns:
|
||||
The fully qualified class path string
|
||||
|
||||
Raises:
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
path = _ATTN_OVERRIDES.get(self, self.value)
|
||||
if not path:
|
||||
raise ValueError(
|
||||
f"Backend {self.name} must be registered before use. "
|
||||
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
|
||||
)
|
||||
if not include_classname:
|
||||
path = path.rsplit(".", 1)[0]
|
||||
return path
|
||||
|
||||
def get_class(self) -> "type[AttentionBackend]":
|
||||
"""Get the backend class (respects overrides).
|
||||
|
||||
Returns:
|
||||
The backend class
|
||||
|
||||
Raises:
|
||||
ImportError: If the backend class cannot be imported
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
return resolve_obj_by_qualname(self.get_path())
|
||||
|
||||
def is_overridden(self) -> bool:
|
||||
"""Check if this backend has been overridden.
|
||||
|
||||
Returns:
|
||||
True if the backend has a registered override
|
||||
"""
|
||||
return self in _ATTN_OVERRIDES
|
||||
|
||||
def clear_override(self) -> None:
|
||||
"""Clear any override for this backend, reverting to the default."""
|
||||
_ATTN_OVERRIDES.pop(self, None)
|
||||
|
||||
|
||||
class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"""Enumeration of all supported mamba attention backends.
|
||||
|
||||
The enum value is the default class path, but this can be overridden
|
||||
at runtime using register_backend().
|
||||
|
||||
To get the actual backend class (respecting overrides), use:
|
||||
backend.get_class()
|
||||
"""
|
||||
|
||||
MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
|
||||
MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
|
||||
SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend"
|
||||
LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
|
||||
GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
|
||||
# Placeholder for third-party/custom backends - must be registered before use
|
||||
# set to None to avoid alias with other backend, whose value is an empty string
|
||||
CUSTOM = None
|
||||
|
||||
def get_path(self, include_classname: bool = True) -> str:
|
||||
"""Get the class path for this backend (respects overrides).
|
||||
|
||||
Returns:
|
||||
The fully qualified class path string
|
||||
|
||||
Raises:
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
path = _MAMBA_ATTN_OVERRIDES.get(self, self.value)
|
||||
if not path:
|
||||
raise ValueError(
|
||||
f"Backend {self.name} must be registered before use. "
|
||||
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
|
||||
)
|
||||
if not include_classname:
|
||||
path = path.rsplit(".", 1)[0]
|
||||
return path
|
||||
|
||||
def get_class(self) -> "type[AttentionBackend]":
|
||||
"""Get the backend class (respects overrides).
|
||||
|
||||
Returns:
|
||||
The backend class
|
||||
|
||||
Raises:
|
||||
ImportError: If the backend class cannot be imported
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
return resolve_obj_by_qualname(self.get_path())
|
||||
|
||||
def is_overridden(self) -> bool:
|
||||
"""Check if this backend has been overridden.
|
||||
|
||||
Returns:
|
||||
True if the backend has a registered override
|
||||
"""
|
||||
return self in _MAMBA_ATTN_OVERRIDES
|
||||
|
||||
def clear_override(self) -> None:
|
||||
"""Clear any override for this backend, reverting to the default."""
|
||||
_MAMBA_ATTN_OVERRIDES.pop(self, None)
|
||||
|
||||
|
||||
MAMBA_TYPE_TO_BACKEND_MAP = {
|
||||
"mamba1": MambaAttentionBackendEnum.MAMBA1.name,
|
||||
"mamba2": MambaAttentionBackendEnum.MAMBA2.name,
|
||||
"short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
|
||||
"linear_attention": MambaAttentionBackendEnum.LINEAR.name,
|
||||
"gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
|
||||
"custom": MambaAttentionBackendEnum.CUSTOM.name,
|
||||
}
|
||||
|
||||
|
||||
_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}
|
||||
_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}
|
||||
|
||||
|
||||
def register_backend(
|
||||
backend: AttentionBackendEnum | MambaAttentionBackendEnum,
|
||||
class_path: str | None = None,
|
||||
is_mamba: bool = False,
|
||||
) -> Callable[[type], type]:
|
||||
"""Register or override a backend implementation.
|
||||
|
||||
Args:
|
||||
backend: The AttentionBackendEnum member to register
|
||||
class_path: Optional class path. If not provided and used as
|
||||
decorator, will be auto-generated from the class.
|
||||
|
||||
Returns:
|
||||
Decorator function if class_path is None, otherwise a no-op
|
||||
|
||||
Examples:
|
||||
# Override an existing attention backend
|
||||
@register_backend(AttentionBackendEnum.FLASH_ATTN)
|
||||
class MyCustomFlashAttn:
|
||||
...
|
||||
|
||||
# Override an existing mamba attention backend
|
||||
@register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
|
||||
class MyCustomMambaAttn:
|
||||
...
|
||||
|
||||
# Register a custom third-party attention backend
|
||||
@register_backend(AttentionBackendEnum.CUSTOM)
|
||||
class MyCustomBackend:
|
||||
...
|
||||
|
||||
# Direct registration
|
||||
register_backend(
|
||||
AttentionBackendEnum.CUSTOM,
|
||||
"my.module.MyCustomBackend"
|
||||
)
|
||||
"""
|
||||
|
||||
def decorator(cls: type) -> type:
|
||||
if is_mamba:
|
||||
_MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" # type: ignore[index]
|
||||
else:
|
||||
_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" # type: ignore[index]
|
||||
return cls
|
||||
|
||||
if class_path is not None:
|
||||
if is_mamba:
|
||||
_MAMBA_ATTN_OVERRIDES[backend] = class_path # type: ignore[index]
|
||||
else:
|
||||
_ATTN_OVERRIDES[backend] = class_path # type: ignore[index]
|
||||
return lambda x: x
|
||||
|
||||
return decorator
|
||||
@@ -7,25 +7,25 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_prefills_and_extends,
|
||||
)
|
||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.rocm_attn import (
|
||||
RocmAttentionBackend,
|
||||
|
||||
@@ -7,17 +7,6 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@@ -25,12 +14,25 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
|
||||
chunked_prefill_paged_decode,
|
||||
)
|
||||
from vllm.v1.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
|
||||
@@ -9,20 +9,20 @@ from typing import ClassVar, Optional
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -7,17 +7,6 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.triton_prefill_attention import context_attention_fwd
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
@@ -28,11 +17,22 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -29,16 +29,16 @@ if TYPE_CHECKING:
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
get_kv_connector_cache_layout,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlice
|
||||
|
||||
|
||||
0
vllm/v1/attention/ops/__init__.py
Normal file
0
vllm/v1/attention/ops/__init__.py
Normal file
459
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
Normal file
459
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Authors:
|
||||
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .prefix_prefill import context_attention_fwd
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_paged_attention_2d(
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
|
||||
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
out_scale_inv,
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
num_queries_per_kv_padded: tl.constexpr, # int
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
PHYSICAL_BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
x: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
stride_k_cache_3: tl.int64, # int
|
||||
stride_k_cache_4: tl.int64, # int
|
||||
stride_v_cache_0: tl.int64, # int
|
||||
stride_v_cache_1: tl.int64, # int
|
||||
stride_v_cache_2: tl.int64, # int
|
||||
stride_v_cache_3: tl.int64, # int
|
||||
filter_by_query_len: tl.constexpr, # bool
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
USE_FP8: tl.constexpr,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
|
||||
seq_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
|
||||
if filter_by_query_len:
|
||||
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
|
||||
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
|
||||
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
|
||||
if cur_batch_query_len > 1:
|
||||
return
|
||||
else:
|
||||
cur_batch_in_all_start_index = seq_idx
|
||||
|
||||
query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
|
||||
0, num_queries_per_kv_padded
|
||||
)
|
||||
|
||||
query_offset = (
|
||||
cur_batch_in_all_start_index * query_stride_0
|
||||
+ query_head_idx[:, None] * query_stride_1
|
||||
)
|
||||
|
||||
head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
|
||||
head_mask = head_mask & (query_head_idx < num_query_heads)
|
||||
|
||||
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1)
|
||||
|
||||
# Q : (num_queries_per_kv, HEAD_SIZE,)
|
||||
Q = tl.load(
|
||||
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
|
||||
mask=dim_mask[None, :] & head_mask[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
if not USE_SINKS:
|
||||
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
|
||||
L = tl.zeros([num_queries_per_kv_padded], dtype=tl.float32)
|
||||
else:
|
||||
M = tl.load(
|
||||
sink_ptr + query_head_idx,
|
||||
mask=head_mask,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
L = tl.where(float("-inf") < M, 1.0, 0.0)
|
||||
|
||||
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32)
|
||||
|
||||
# sequence len for this particular sequence
|
||||
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
||||
|
||||
# alibi slope for this head
|
||||
if USE_ALIBI_SLOPES:
|
||||
alibi_slope = tl.load(
|
||||
alibi_slopes_ptr + query_head_idx, mask=head_mask, other=0.0
|
||||
)
|
||||
|
||||
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_SIZE)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
# iterate through tiles
|
||||
for j in range(0, num_blocks):
|
||||
start_n = j * BLOCK_SIZE
|
||||
# Calculate the logical location within a non-standard physical block,
|
||||
# such as 544 in Qwen/Qwen3-Next-80B-A3B-Thinking.
|
||||
# Supports non-contiguous mapping
|
||||
# from logical blocks to physical blocks
|
||||
abs_token_idx = start_n + offs_n
|
||||
l_block_idx = abs_token_idx // PHYSICAL_BLOCK_SIZE
|
||||
# Vectorized loading of physical block IDs
|
||||
p_block_idx = tl.load(block_tables_ptr + block_table_offset + l_block_idx)
|
||||
internal_offsets = abs_token_idx % PHYSICAL_BLOCK_SIZE
|
||||
|
||||
# 5D addressing logic of K
|
||||
k_offset = (
|
||||
p_block_idx[None, :] * stride_k_cache_0
|
||||
+ kv_head_idx * stride_k_cache_1
|
||||
+ (offs_d[:, None] // x) * stride_k_cache_2
|
||||
+ internal_offsets[None, :] * stride_k_cache_3
|
||||
+ (offs_d[:, None] % x) * stride_k_cache_4
|
||||
)
|
||||
|
||||
# 4D addressing logic of V (Slot is innermost)
|
||||
v_offset = (
|
||||
p_block_idx[:, None] * stride_v_cache_0
|
||||
+ kv_head_idx * stride_v_cache_1
|
||||
+ offs_d[None, :] * stride_v_cache_2
|
||||
+ internal_offsets[:, None] * stride_v_cache_3
|
||||
)
|
||||
|
||||
# K : (HEAD_SIZE, BLOCK_SIZE)
|
||||
K_load = tl.load(
|
||||
key_cache_ptr + k_offset,
|
||||
mask=dim_mask[:, None],
|
||||
other=0.0,
|
||||
eviction_policy="evict_last",
|
||||
)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
|
||||
else:
|
||||
K = K_load
|
||||
|
||||
# V : (BLOCK_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(
|
||||
value_cache_ptr + v_offset,
|
||||
mask=dim_mask[None, :],
|
||||
other=0.0,
|
||||
eviction_policy="evict_last",
|
||||
)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
|
||||
seq_mask = seq_offset[None, :] < boundary
|
||||
|
||||
# First calculate the dot, then apply the mask.
|
||||
qk = scale * tl.dot(Q, K)
|
||||
S = tl.where(head_mask[:, None] & seq_mask, qk, float("-inf"))
|
||||
|
||||
context_len = seq_len - 1
|
||||
|
||||
if SLIDING_WINDOW > 0:
|
||||
S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, -10000)
|
||||
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
# compute running maximum
|
||||
# m_j : (num_queries_per_kv,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
|
||||
# P : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
p = tl.exp(S - m_j[:, None])
|
||||
p = tl.where(m_j[:, None] == float("-inf"), 0.0, p)
|
||||
|
||||
# l_j : (num_queries_per_kv,)
|
||||
l_j = tl.sum(p, axis=1)
|
||||
|
||||
# alpha : (num_queries_per_kv, )
|
||||
alpha = tl.exp(M - m_j)
|
||||
alpha = tl.where(float("-inf") == M, 0.0, alpha)
|
||||
|
||||
# acc : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update constants
|
||||
L = L * alpha + l_j
|
||||
M = m_j
|
||||
|
||||
# acc : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
acc += tl.dot(p.to(V.dtype), V)
|
||||
|
||||
# epilogue
|
||||
acc = acc / (L[:, None] + 1e-10)
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale_inv)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
|
||||
output_offset = (
|
||||
cur_batch_in_all_start_index * output_stride_0
|
||||
+ query_head_idx * output_stride_1
|
||||
)
|
||||
|
||||
tl.store(
|
||||
output_ptr + output_offset[:, None] + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
|
||||
acc,
|
||||
mask=dim_mask[None, :] & head_mask[:, None],
|
||||
)
|
||||
|
||||
|
||||
def chunked_prefill_paged_decode(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_table,
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
max_query_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
sm_scale=None,
|
||||
output_scale=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
is_block_table_ptr: bool = False,
|
||||
):
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / (query.shape[2] ** 0.5)
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
|
||||
if sliding_window is None or sliding_window <= 0:
|
||||
sliding_window = 0
|
||||
|
||||
if max_query_len > 1:
|
||||
context_attention_fwd(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
o=output,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
b_loc=block_table,
|
||||
b_start_loc=query_start_loc,
|
||||
b_seq_len=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
max_input_len=max_query_len,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
sliding_window=sliding_window,
|
||||
sm_scale=sm_scale,
|
||||
skip_decode=True,
|
||||
fp8_out_scale=output_scale,
|
||||
sinks=sinks,
|
||||
)
|
||||
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs = len(seq_lens)
|
||||
num_query_heads = query.shape[1]
|
||||
num_kv_heads = key.shape[1]
|
||||
num_queries_per_kv = query.shape[1] // key.shape[1]
|
||||
head_size = query.shape[2]
|
||||
|
||||
# Conversion of FP8 Tensor from uint8 storage to
|
||||
# appropriate torch.dtype for interpretation by Triton
|
||||
if "fp8" in kv_cache_dtype:
|
||||
assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
|
||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||
target_dtype = current_platform.fp8_dtype()
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
target_dtype = torch.float8_e5m2
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported FP8 kv_cache_dtype {kv_cache_dtype}: "
|
||||
f"should be one of 'fp8', 'fp8_e4m3', 'fp8_e5m2'."
|
||||
)
|
||||
|
||||
key_cache = key_cache.view(target_dtype)
|
||||
value_cache = value_cache.view(target_dtype)
|
||||
|
||||
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16)
|
||||
|
||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||
|
||||
use_custom = use_rocm_custom_paged_attention(
|
||||
query.dtype,
|
||||
head_size,
|
||||
block_size,
|
||||
num_queries_per_kv,
|
||||
max_seq_len,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
alibi_slopes,
|
||||
sinks,
|
||||
)
|
||||
# Triton is only forced when encountering a non-standard block
|
||||
# like Qwen3 with a size of 544.
|
||||
# 1. Check if block_size is a power of 2 (16, 32, 64...)
|
||||
# 2. If it's a power of 2, we trust the vLLM's native use_custom decision.
|
||||
# 3. If it's not a power of 2 (such as Qwen3's 544),
|
||||
# then our Triton path is forced.
|
||||
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
|
||||
if not is_pow2:
|
||||
use_custom = False
|
||||
|
||||
if use_custom:
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
max_num_partitions = (
|
||||
max_seq_len + _PARTITION_SIZE_ROCM - 1
|
||||
) // _PARTITION_SIZE_ROCM
|
||||
assert _PARTITION_SIZE_ROCM % block_size == 0
|
||||
total_num_seq = block_table.shape[0]
|
||||
tmp_output = torch.empty(
|
||||
size=(total_num_seq, num_query_heads, max_num_partitions, head_size),
|
||||
dtype=query.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(total_num_seq, num_query_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
ops.paged_attention_rocm(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale=sm_scale,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
query_start_loc=query_start_loc,
|
||||
block_size=block_size,
|
||||
max_seq_len=max_seq_len,
|
||||
alibi_slopes=alibi_slopes,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
fp8_out_scale=output_scale,
|
||||
)
|
||||
else:
|
||||
real_block_size = value_cache.shape[3]
|
||||
# The standard model directly uses the original block_size.
|
||||
# Non-standard 544 uses 32 to accommodate integer division logic.
|
||||
TRITON_BLOCK_SIZE = block_size if is_pow2 else 32
|
||||
if is_block_table_ptr:
|
||||
# Using the physical base address of tensors
|
||||
kv_element_size = key_cache.element_size()
|
||||
block_byte_stride = key_cache.stride(0) * kv_element_size
|
||||
# Get the starting physical address of the KV Cache
|
||||
base_addr = key_cache.data_ptr()
|
||||
|
||||
# Normalization: Directly calculate the block offset
|
||||
# of the pointer relative to the base address
|
||||
processed_block_table = ((block_table - base_addr) // block_byte_stride).to(
|
||||
torch.int32
|
||||
)
|
||||
else:
|
||||
processed_block_table = block_table.to(torch.int32)
|
||||
|
||||
kernel_paged_attention_2d[
|
||||
(
|
||||
num_seqs,
|
||||
num_kv_heads,
|
||||
)
|
||||
](
|
||||
output_ptr=output,
|
||||
query_ptr=query,
|
||||
key_cache_ptr=key_cache,
|
||||
value_cache_ptr=value_cache,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=processed_block_table,
|
||||
seq_lens_ptr=seq_lens,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
scale=sm_scale,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out_scale_inv=1.0 / output_scale if output_scale is not None else 1.0,
|
||||
num_query_heads=num_query_heads,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
num_queries_per_kv_padded=num_queries_per_kv_padded,
|
||||
block_table_stride=processed_block_table.stride(0),
|
||||
query_stride_0=query.stride(0),
|
||||
query_stride_1=query.stride(1),
|
||||
output_stride_0=output.stride(0),
|
||||
output_stride_1=output.stride(1),
|
||||
BLOCK_SIZE=TRITON_BLOCK_SIZE,
|
||||
PHYSICAL_BLOCK_SIZE=real_block_size,
|
||||
HEAD_SIZE=head_size,
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
SLIDING_WINDOW=sliding_window,
|
||||
x=key_cache.shape[4],
|
||||
stride_k_cache_0=key_cache.stride(0),
|
||||
stride_k_cache_1=key_cache.stride(1),
|
||||
stride_k_cache_2=key_cache.stride(2),
|
||||
stride_k_cache_3=key_cache.stride(3),
|
||||
stride_k_cache_4=key_cache.stride(4),
|
||||
stride_v_cache_0=value_cache.stride(0),
|
||||
stride_v_cache_1=value_cache.stride(1),
|
||||
stride_v_cache_2=value_cache.stride(2),
|
||||
stride_v_cache_3=value_cache.stride(3),
|
||||
filter_by_query_len=True,
|
||||
query_start_len_ptr=query_start_loc,
|
||||
USE_SINKS=sinks is not None,
|
||||
USE_FP8=output_scale is not None,
|
||||
)
|
||||
469
vllm/v1/attention/ops/common.py
Normal file
469
vllm/v1/attention/ops/common.py
Normal file
@@ -0,0 +1,469 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _correct_attn_cp_out_kernel(
|
||||
outputs_ptr,
|
||||
new_output_ptr,
|
||||
lses_ptr,
|
||||
vlse_ptr,
|
||||
outputs_stride_B,
|
||||
outputs_stride_H,
|
||||
outputs_stride_D,
|
||||
lses_stride_N,
|
||||
lses_stride_B,
|
||||
lses_stride_H,
|
||||
lse_idx,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
N_ROUNDED: tl.constexpr,
|
||||
IS_BASE_E: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Apply the all-gathered lses to correct each local rank's attention
|
||||
output. we still need perform a cross-rank reduction to obtain the
|
||||
final attention output.
|
||||
|
||||
Args:
|
||||
outputs_ptr (triton.PointerType):
|
||||
Pointer to input tensor of shape [ B, H, D ]
|
||||
lses_ptr (triton.PointerType):
|
||||
Pointer to input tensor of shape [ N, B, H ]
|
||||
new_output_ptr (triton.PointerType):
|
||||
Pointer to output tensor of shape [ B, H, D ]
|
||||
vlse_ptr (triton.PointerType):
|
||||
Pointer to output tensor of shape [ B, H ]
|
||||
"""
|
||||
batch_idx = tl.program_id(axis=0).to(tl.int64)
|
||||
head_idx = tl.program_id(axis=1).to(tl.int64)
|
||||
d_offsets = tl.arange(0, HEAD_DIM)
|
||||
num_n_offsets = tl.arange(0, N_ROUNDED)
|
||||
|
||||
# shape = [N]
|
||||
lse_offsets = (
|
||||
num_n_offsets * lses_stride_N
|
||||
+ batch_idx * lses_stride_B
|
||||
+ head_idx * lses_stride_H
|
||||
)
|
||||
|
||||
# calc final lse
|
||||
lse = tl.load(lses_ptr + lse_offsets)
|
||||
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
|
||||
lse_max = tl.max(lse, axis=0)
|
||||
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
|
||||
lse -= lse_max
|
||||
if IS_BASE_E:
|
||||
lse_exp = tl.exp(lse)
|
||||
lse_acc = tl.sum(lse_exp, axis=0)
|
||||
lse = tl.log(lse_acc)
|
||||
else:
|
||||
lse_exp = tl.exp2(lse)
|
||||
lse_acc = tl.sum(lse_exp, axis=0)
|
||||
lse = tl.log2(lse_acc)
|
||||
lse += lse_max
|
||||
|
||||
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
|
||||
tl.store(vlse_ptr + lse_offsets, lse)
|
||||
|
||||
# shape = [D]
|
||||
output_offsets = (
|
||||
batch_idx * outputs_stride_B
|
||||
+ head_idx * outputs_stride_H
|
||||
+ d_offsets * outputs_stride_D
|
||||
)
|
||||
|
||||
# correct output
|
||||
lse_offset = (
|
||||
lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H
|
||||
)
|
||||
lse_tmp = tl.load(lses_ptr + lse_offset)
|
||||
lse_finally = lse_tmp - lse
|
||||
lse_finally = tl.where(
|
||||
(lse_finally != lse_finally) | (lse_finally == float("inf")),
|
||||
-float("inf"),
|
||||
lse_finally,
|
||||
)
|
||||
factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
|
||||
output = tl.load(outputs_ptr + output_offsets)
|
||||
output = output * factor
|
||||
|
||||
tl.store(new_output_ptr + output_offsets, output)
|
||||
|
||||
|
||||
class CPTritonContext:
|
||||
"""The CPTritonContext is used to avoid recompilation of the Triton JIT."""
|
||||
|
||||
def __init__(self):
|
||||
self.inner_kernel = None
|
||||
|
||||
def call_kernel(self, kernel, grid, *regular_args, **const_args):
|
||||
if self.inner_kernel is None:
|
||||
self.inner_kernel = kernel[grid](*regular_args, **const_args)
|
||||
else:
|
||||
self.inner_kernel[grid](*regular_args)
|
||||
|
||||
|
||||
def correct_attn_out(
|
||||
out: torch.Tensor,
|
||||
lses: torch.Tensor,
|
||||
cp_rank: int,
|
||||
ctx: CPTritonContext,
|
||||
is_lse_base_on_e: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Correct the attention output using the all-gathered lses.
|
||||
|
||||
Args:
|
||||
out: Tensor of shape [ B, H, D ]
|
||||
lses: Tensor of shape [ N, B, H ]
|
||||
cp_rank: Current rank in the context-parallel group
|
||||
ctx: Triton context to avoid recompilation
|
||||
|
||||
Returns:
|
||||
Tuple of (out, lse) with corrected attention and final log-sum-exp.
|
||||
"""
|
||||
if ctx is None:
|
||||
ctx = CPTritonContext()
|
||||
|
||||
# --- Normalize to 3D views ---
|
||||
if out.ndim == 4 and out.shape[1] == 1:
|
||||
out = out.squeeze(1)
|
||||
assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"
|
||||
|
||||
if lses.ndim == 4 and lses.shape[-1] == 1:
|
||||
lses = lses.squeeze(-1)
|
||||
if lses.ndim == 4 and lses.shape[1] == 1:
|
||||
lses = lses.squeeze(1)
|
||||
assert lses.ndim == 3, (
|
||||
f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
|
||||
f"got {tuple(lses.shape)}"
|
||||
)
|
||||
|
||||
B, H, D = out.shape
|
||||
N = lses.shape[0]
|
||||
|
||||
# Strides after we normalized shapes to 3-D views. The kernel computes
|
||||
# offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
|
||||
# have the same B/H stride layout as a slice of `lses`.
|
||||
o_sB, o_sH, o_sD = out.stride()
|
||||
l_sN, l_sB, l_sH = lses.stride()
|
||||
|
||||
# Allocate LSE with the same B/H strides as `lses` so writes land correctly
|
||||
# even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
|
||||
lse = torch.empty_strided(
|
||||
(B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
|
||||
)
|
||||
|
||||
# Kernel launch config
|
||||
grid = (B, H, 1)
|
||||
|
||||
regular_args = (
|
||||
out,
|
||||
out,
|
||||
lses,
|
||||
lse,
|
||||
o_sB,
|
||||
o_sH,
|
||||
o_sD,
|
||||
l_sN,
|
||||
l_sB,
|
||||
l_sH,
|
||||
cp_rank,
|
||||
)
|
||||
const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e}
|
||||
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
|
||||
return out, lse
|
||||
|
||||
|
||||
def _cp_lse_common(
|
||||
cp_attn_out: torch.Tensor,
|
||||
cp_attn_lse: torch.Tensor,
|
||||
cp_group: GroupCoordinator,
|
||||
ctx: CPTritonContext | None = None,
|
||||
is_lse_base_on_e=True,
|
||||
):
|
||||
"""
|
||||
cp_attn_out: [ B, H, D ]
|
||||
cp_attn_lse: [ B, H ]
|
||||
"""
|
||||
if cp_group.world_size == 1:
|
||||
return cp_attn_out
|
||||
|
||||
if ctx is None:
|
||||
ctx = CPTritonContext()
|
||||
|
||||
lses = torch.empty(
|
||||
(cp_group.world_size,) + cp_attn_lse.shape,
|
||||
dtype=cp_attn_lse.dtype,
|
||||
device=cp_attn_lse.device,
|
||||
)
|
||||
|
||||
cp_attn_lse = cp_attn_lse.contiguous()
|
||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
||||
out, lse = correct_attn_out(
|
||||
cp_attn_out,
|
||||
lses,
|
||||
cp_group.rank_in_group,
|
||||
ctx,
|
||||
is_lse_base_on_e=is_lse_base_on_e,
|
||||
)
|
||||
return out, lse
|
||||
|
||||
|
||||
def cp_lse_ag_out_rs(
|
||||
cp_attn_out: torch.Tensor,
|
||||
cp_attn_lse: torch.Tensor,
|
||||
cp_group: GroupCoordinator,
|
||||
ctx: CPTritonContext | None = None,
|
||||
return_lse: bool = False,
|
||||
is_lse_base_on_e=True,
|
||||
):
|
||||
"""
|
||||
cp_attn_out: [ B, H, D ]
|
||||
cp_attn_lse: [ B, H ]
|
||||
"""
|
||||
out, lse = _cp_lse_common(
|
||||
cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
|
||||
)
|
||||
out = cp_group.reduce_scatter(out, dim=1)
|
||||
|
||||
if return_lse:
|
||||
cp_num_heads = lse.shape[1] // cp_group.world_size
|
||||
cp_rank = cp_group.rank_in_group
|
||||
lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
|
||||
return out, lse
|
||||
return out
|
||||
|
||||
|
||||
def cp_lse_ag_out_ar(
|
||||
cp_attn_out: torch.Tensor,
|
||||
cp_attn_lse: torch.Tensor,
|
||||
cp_group: GroupCoordinator,
|
||||
ctx: CPTritonContext | None = None,
|
||||
return_lse: bool = False,
|
||||
is_lse_base_on_e=True,
|
||||
):
|
||||
"""
|
||||
cp_attn_out: [ B, H, D ]
|
||||
cp_attn_lse: [ B, H ]
|
||||
"""
|
||||
out, lse = _cp_lse_common(
|
||||
cp_attn_out, cp_attn_lse, cp_group, ctx=ctx, is_lse_base_on_e=is_lse_base_on_e
|
||||
)
|
||||
out = cp_group.all_reduce(out)
|
||||
|
||||
if return_lse:
|
||||
return out, lse
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _pack_seq_kernel(
|
||||
x_ptr, # [N, D]
|
||||
out_ptr, # [B, Lmax, D]
|
||||
lengths_ptr, # *i32, [B]
|
||||
N: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
Lmax: tl.constexpr,
|
||||
PAD_VALUE: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr, # timesteps per program
|
||||
BLOCK_D: tl.constexpr, # features per program
|
||||
):
|
||||
pid_b = tl.program_id(0) # batch id
|
||||
pid_t = tl.program_id(1) # block over time dimension
|
||||
pid_d = tl.program_id(2) # block over feature dimension
|
||||
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
|
||||
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
|
||||
|
||||
# Compute start index and sequence length from cumulative lengths
|
||||
in_start = 0
|
||||
for i in range(pid_b):
|
||||
in_start += tl.load(lengths_ptr + i)
|
||||
seq_len = tl.load(lengths_ptr + pid_b)
|
||||
|
||||
# valid time positions for this block
|
||||
t_mask = off_t < Lmax
|
||||
|
||||
# compute input row indices for valid (b, t)
|
||||
in_row = in_start + off_t
|
||||
valid_row = (off_t < seq_len) & t_mask
|
||||
|
||||
# Pointers
|
||||
# x_ptr: row-major [N, D]
|
||||
x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]
|
||||
|
||||
# out_ptr: row-major [B, Lmax, D]
|
||||
out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]
|
||||
|
||||
# Initialize with PAD (cast will occur as needed based on out_ptr dtype)
|
||||
d_mask = off_d[None, :] < D
|
||||
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
|
||||
tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)
|
||||
|
||||
# Load & write only where within seq_len
|
||||
x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
|
||||
tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
|
||||
|
||||
|
||||
def pack_seq_triton(
|
||||
x: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
pad_value: float = -float("inf"),
|
||||
block_t: int = 64,
|
||||
block_d: int = 64,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pack sequences of different lengths into a batched tensor.
|
||||
|
||||
Args:
|
||||
x: [N, ...] - input tensor where N is total number of tokens
|
||||
lengths: [B] - sequence lengths for each batch
|
||||
pad_value: value to use for padding
|
||||
block_t: block size for time dimension
|
||||
block_d: block size for feature dimension
|
||||
|
||||
Returns:
|
||||
packed: [B, Lmax, ...] - packed tensor
|
||||
"""
|
||||
|
||||
# Handle multi-dimensional input by reshaping to (N, -1)
|
||||
original_shape = x.shape
|
||||
if len(original_shape) > 2:
|
||||
N = original_shape[0]
|
||||
x_reshaped = x.reshape(N, -1)
|
||||
D = x_reshaped.shape[1]
|
||||
else:
|
||||
N, D = x.shape
|
||||
x_reshaped = x
|
||||
|
||||
B = lengths.numel()
|
||||
Lmax = int(lengths.max().item())
|
||||
|
||||
# Starts are computed inside the kernel from lengths
|
||||
|
||||
out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)
|
||||
|
||||
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
|
||||
_pack_seq_kernel[grid](
|
||||
x_reshaped,
|
||||
out,
|
||||
lengths.int(),
|
||||
N,
|
||||
D,
|
||||
Lmax,
|
||||
PAD_VALUE=float(pad_value),
|
||||
BLOCK_T=block_t,
|
||||
BLOCK_D=block_d,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
)
|
||||
|
||||
# Reshape output back to original dimensions (except first dimension)
|
||||
if len(original_shape) > 2:
|
||||
output_shape = (B, Lmax) + original_shape[1:]
|
||||
out = out.reshape(output_shape)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _unpack_seq_triton_kernel(
|
||||
packed_ptr, # [B, Lmax, D]
|
||||
out_ptr, # [N, D]
|
||||
lengths_ptr, # *i32, [B]
|
||||
B: tl.constexpr,
|
||||
Lmax: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr, # timesteps per program
|
||||
BLOCK_D: tl.constexpr, # features per program
|
||||
):
|
||||
pid_b = tl.program_id(0) # batch id
|
||||
pid_t = tl.program_id(1) # block over time dimension
|
||||
pid_d = tl.program_id(2) # block over feature dimension
|
||||
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
|
||||
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
|
||||
|
||||
# bounds: compute start from cumulative lengths
|
||||
in_start = 0
|
||||
for i in range(pid_b):
|
||||
in_start += tl.load(lengths_ptr + i)
|
||||
seq_len = tl.load(lengths_ptr + pid_b)
|
||||
|
||||
# valid time positions for this block
|
||||
t_mask = off_t < Lmax
|
||||
valid_row = (off_t < seq_len) & t_mask
|
||||
|
||||
# compute output row indices for valid (b, t)
|
||||
out_row = in_start + off_t
|
||||
|
||||
# Pointers
|
||||
# packed_ptr: row-major [B, Lmax, D]
|
||||
packed_row_ptr = packed_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]
|
||||
|
||||
# out_ptr: row-major [N, D]
|
||||
out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]
|
||||
|
||||
# Load from packed tensor and store to output
|
||||
d_mask = off_d[None, :] < D
|
||||
packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
|
||||
tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
|
||||
|
||||
|
||||
def unpack_seq_triton(
|
||||
packed_tensor: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
block_t: int = 64,
|
||||
block_d: int = 64,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Unpack a packed decode query tensor back to the original format.
|
||||
Efficient Triton implementation.
|
||||
|
||||
Args:
|
||||
packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
|
||||
lengths: [B] - sequence lengths for each batch
|
||||
block_t: block size for time dimension
|
||||
block_d: block size for feature dimension
|
||||
|
||||
Returns:
|
||||
unpacked_tensor: [N, ...] where N = sum(lengths)
|
||||
"""
|
||||
|
||||
# Handle multi-dimensional input by reshaping to (B, Lmax, -1)
|
||||
original_shape = packed_tensor.shape
|
||||
if len(original_shape) > 3:
|
||||
B, Lmax = original_shape[:2]
|
||||
packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
|
||||
D = packed_reshaped.shape[2]
|
||||
else:
|
||||
B, Lmax, D = packed_tensor.shape
|
||||
packed_reshaped = packed_tensor
|
||||
|
||||
# Calculate total number of elements
|
||||
N = int(lengths.sum().item())
|
||||
|
||||
out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype)
|
||||
|
||||
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
|
||||
_unpack_seq_triton_kernel[grid](
|
||||
packed_reshaped,
|
||||
out,
|
||||
lengths.int(),
|
||||
B,
|
||||
Lmax,
|
||||
D,
|
||||
BLOCK_T=block_t,
|
||||
BLOCK_D=block_d,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
)
|
||||
|
||||
# Reshape output back to original dimensions (except first dimension)
|
||||
if len(original_shape) > 3:
|
||||
output_shape = (N,) + original_shape[2:]
|
||||
out = out.reshape(output_shape)
|
||||
|
||||
return out
|
||||
254
vllm/v1/attention/ops/flashmla.py
Normal file
254
vllm/v1/attention/ops/flashmla.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_C # noqa: F401
|
||||
|
||||
_flashmla_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_extension_C # noqa: F401
|
||||
|
||||
_flashmla_extension_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
|
||||
|
||||
def _is_flashmla_available() -> tuple[bool, str | None]:
|
||||
if not _flashmla_C_AVAILABLE:
|
||||
return (
|
||||
False,
|
||||
"vllm._flashmla_C is not available, likely was not "
|
||||
"compiled due to insufficient nvcc version or a supported arch "
|
||||
"was not in the list of target arches to compile for.",
|
||||
)
|
||||
if not _flashmla_extension_C_AVAILABLE:
|
||||
return (
|
||||
False,
|
||||
"vllm._flashmla_extension_C is not available, likely "
|
||||
"was not compiled due to a build error.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def is_flashmla_dense_supported() -> tuple[bool, str | None]:
|
||||
"""
|
||||
Return: is_supported_flag, unsupported_reason (optional).
|
||||
"""
|
||||
is_availble, maybe_reason = _is_flashmla_available()
|
||||
if not is_availble:
|
||||
return False, maybe_reason
|
||||
if not current_platform.is_device_capability_family(90):
|
||||
return False, "FlashMLA Dense is only supported on Hopper devices."
|
||||
return True, None
|
||||
|
||||
|
||||
def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
|
||||
"""
|
||||
Return: is_supported_flag, unsupported_reason (optional).
|
||||
"""
|
||||
is_availble, maybe_reason = _is_flashmla_available()
|
||||
if not is_availble:
|
||||
return False, maybe_reason
|
||||
if not (
|
||||
current_platform.is_device_capability_family(90)
|
||||
or current_platform.is_device_capability_family(100)
|
||||
):
|
||||
return (
|
||||
False,
|
||||
"FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
|
||||
)
|
||||
return True, None
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_q_tokens_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
num_heads_q: int | None = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
topk: int | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
- cache_seqlens: (batch_size), dtype torch.int32.
|
||||
- num_q_tokens_per_head_k:
|
||||
Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
|
||||
- num_heads_k: The number of k heads.
|
||||
- num_heads_q:
|
||||
The number of q heads.
|
||||
This argument is optional when sparse attention is not enabled
|
||||
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
|
||||
- topk: If not None, sparse attention will be enabled,
|
||||
and only tokens in the `indices` array
|
||||
passed to `flash_mla_with_kvcache_sm90` will be attended to.
|
||||
|
||||
Returns:
|
||||
- tile_scheduler_metadata:
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
- num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
if is_fp8_kvcache and topk is None:
|
||||
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
|
||||
cache_seqlens,
|
||||
num_q_tokens_per_head_k,
|
||||
num_heads_k,
|
||||
)
|
||||
return torch.ops._flashmla_C.get_mla_decoding_metadata(
|
||||
cache_seqlens,
|
||||
num_q_tokens_per_head_k,
|
||||
num_heads_k,
|
||||
num_heads_q,
|
||||
is_fp8_kvcache,
|
||||
topk,
|
||||
)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
head_dim_v: int,
|
||||
tile_scheduler_metadata: torch.Tensor,
|
||||
num_splits: torch.Tensor,
|
||||
softmax_scale: float | None = None,
|
||||
causal: bool = False,
|
||||
descale_q: torch.Tensor | None = None,
|
||||
descale_k: torch.Tensor | None = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
indices: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
- cache_seqlens: (batch_size), torch.int32.
|
||||
- head_dim_v: Head dimension of v.
|
||||
- tile_scheduler_metadata:
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
|
||||
returned by get_mla_metadata.
|
||||
- num_splits:
|
||||
(batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||
- softmax_scale: float.
|
||||
The scale of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(head_dim).
|
||||
- causal: bool. Whether to apply causal attention mask.
|
||||
- descale_q: (batch_size),
|
||||
torch.float32. Descaling factors for Q, used for fp8 quantization.
|
||||
- descale_k: (batch_size),
|
||||
torch.float32. Descaling factors for K, used for fp8 quantization.
|
||||
- is_fp8_kvcache: bool.
|
||||
Whether the k_cache and v_cache are in fp8 format.
|
||||
For the format of FP8 KV cache, please refer to README.md
|
||||
- indices: (batch_size, seq_len_q, topk), torch.int32.
|
||||
If not None, sparse attention will be enabled,
|
||||
and only tokens in the `indices` array will be attended to.
|
||||
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
|
||||
For details about how to set up `indices`, please refer to README.md.
|
||||
|
||||
Returns:
|
||||
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
if indices is not None:
|
||||
# NOTE (zyongye): sparse attention is also causal
|
||||
# since it only attend to the tokens before
|
||||
# but here `causal` should not be specified
|
||||
assert not causal, "causal must be `false` if sparse attention is enabled."
|
||||
assert (descale_q is None) == (descale_k is None), (
|
||||
"descale_q and descale_k should be both None or both not None"
|
||||
)
|
||||
|
||||
if indices is None and q.element_size() == 1:
|
||||
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
|
||||
q,
|
||||
k_cache,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
descale_q,
|
||||
descale_k,
|
||||
)
|
||||
else:
|
||||
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
q,
|
||||
k_cache,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
is_fp8_kvcache,
|
||||
indices,
|
||||
)
|
||||
return out, softmax_lse
|
||||
|
||||
|
||||
def flash_mla_sparse_prefill(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
sm_scale: float,
|
||||
d_v: int = 512,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sparse attention prefill kernel
|
||||
|
||||
Args:
|
||||
- q: [s_q, h_q, d_qk], bfloat16
|
||||
- kv: [s_kv, h_kv, d_qk], bfloat16
|
||||
- indices: [s_q, h_kv, topk], int32.
|
||||
Invalid indices should be set to -1 or numbers >= s_kv
|
||||
- sm_scale: float
|
||||
- d_v: The dimension of value vectors. Can only be 512
|
||||
|
||||
Returns:
|
||||
- (output, max_logits, lse)
|
||||
About the definition of output,
|
||||
max_logits and lse, please refer to README.md
|
||||
- output: [s_q, h_q, d_v], bfloat16
|
||||
- max_logits: [s_q, h_q], float
|
||||
- lse: [s_q, h_q], float, 2-based log-sum-exp
|
||||
"""
|
||||
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
|
||||
return results
|
||||
|
||||
|
||||
#
|
||||
# TODO: Add fake functions
|
||||
#
|
||||
# @register_fake("_flashmla_C::get_mla_metadata")
|
||||
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
||||
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
47
vllm/v1/attention/ops/merge_attn_states.py
Normal file
47
vllm/v1/attention/ops/merge_attn_states.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def merge_attn_states(
|
||||
output: torch.Tensor,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
|
||||
# does not support FP8 dtype, fallback to use Triton kernel.
|
||||
def supported_dtypes(o: torch.Tensor) -> bool:
|
||||
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
|
||||
|
||||
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA
|
||||
# kernel load/store 128b(16 bytes) per memory issue within
|
||||
# thread. Namely, the headsize(headdim) must be multiple of
|
||||
# pack_size (float32 -> 4, half/bfloat16 -> 8).
|
||||
def supported_headdim(o: torch.Tensor) -> bool:
|
||||
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
if o.dtype == torch.float32:
|
||||
return headdim % 4 == 0
|
||||
return headdim % 8 == 0
|
||||
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and supported_dtypes(output)
|
||||
and supported_headdim(output)
|
||||
):
|
||||
from vllm._custom_ops import merge_attn_states
|
||||
|
||||
return merge_attn_states(
|
||||
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
|
||||
)
|
||||
else:
|
||||
from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
|
||||
return merge_attn_states(
|
||||
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
|
||||
)
|
||||
55
vllm/v1/attention/ops/paged_attn.py
Normal file
55
vllm/v1/attention/ops/paged_attn.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm import _custom_ops
|
||||
|
||||
ops = _custom_ops
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
|
||||
ops = ipex_ops
|
||||
|
||||
|
||||
class PagedAttention:
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||
return key_cache, value_cache
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> None:
|
||||
ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
130
vllm/v1/attention/ops/pallas_kv_cache_update.py
Normal file
130
vllm/v1/attention/ops/pallas_kv_cache_update.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
|
||||
import jax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
def _kv_cache_update_kernel(
|
||||
# Prefetch
|
||||
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
|
||||
# new_kv_start, slice_len)
|
||||
num_slices_ref, # [1]
|
||||
# Input
|
||||
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
|
||||
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
|
||||
# head_dim]
|
||||
# Output
|
||||
_, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
|
||||
# Scratch
|
||||
scratch, # [num_slices_per_block, page_size, num_combined_kv_heads,
|
||||
# head_dim]
|
||||
sem,
|
||||
):
|
||||
async_copies = []
|
||||
block_idx = pl.program_id(0)
|
||||
num_slices_per_block = scratch.shape[0]
|
||||
|
||||
# Copy from new_kv_hbm_ref to scratch
|
||||
for i in range(num_slices_per_block):
|
||||
offset_i = i + block_idx * num_slices_per_block
|
||||
new_kv_start = jax.lax.select(
|
||||
offset_i < num_slices_ref[0], slices_ref[1, offset_i], 0
|
||||
)
|
||||
length = jax.lax.select(
|
||||
offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
|
||||
)
|
||||
async_copy = pltpu.make_async_copy(
|
||||
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
|
||||
scratch.at[i, pl.ds(0, length), ...],
|
||||
sem,
|
||||
)
|
||||
async_copy.start()
|
||||
async_copies.append(async_copy)
|
||||
|
||||
for async_copy in async_copies:
|
||||
async_copy.wait()
|
||||
|
||||
# Copy from scratch to kv_cache_hbm_ref
|
||||
async_copies.clear()
|
||||
for i in range(num_slices_per_block):
|
||||
offset_i = i + block_idx * num_slices_per_block
|
||||
kv_cache_start = jax.lax.select(
|
||||
offset_i < num_slices_ref[0], slices_ref[0, offset_i], 0
|
||||
)
|
||||
length = jax.lax.select(
|
||||
offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
|
||||
)
|
||||
async_copy = pltpu.make_async_copy(
|
||||
scratch.at[i, pl.ds(0, length), ...],
|
||||
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
|
||||
sem,
|
||||
)
|
||||
async_copy.start()
|
||||
async_copies.append(async_copy)
|
||||
for async_copy in async_copies:
|
||||
async_copy.wait()
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=["page_size", "num_slices_per_block"],
|
||||
)
|
||||
def kv_cache_update(
|
||||
# [total_num_token, num_combined_kv_heads, head_dim]
|
||||
new_kv: jax.Array,
|
||||
# [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
|
||||
slices: jax.Array,
|
||||
# [total_num_pages * page_size, num_combined_kv_heads, head_dim]
|
||||
kv_cache: jax.Array,
|
||||
# [1]
|
||||
num_kv_update_slices: jax.Array,
|
||||
*,
|
||||
page_size: int = 32,
|
||||
num_slices_per_block: int = 8,
|
||||
):
|
||||
_, num_combined_kv_heads, head_dim = new_kv.shape
|
||||
assert kv_cache.shape[1] == num_combined_kv_heads
|
||||
assert kv_cache.shape[2] == head_dim
|
||||
assert head_dim % 128 == 0
|
||||
# TODO: Add dynamic check to make sure that the all the slice lengths are
|
||||
# smaller or equal to page_size
|
||||
|
||||
in_specs = [
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
]
|
||||
|
||||
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
|
||||
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
|
||||
|
||||
scalar_prefetches = [slices, num_kv_update_slices]
|
||||
scratch = pltpu.VMEM(
|
||||
(num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
|
||||
new_kv.dtype,
|
||||
)
|
||||
|
||||
scratch_shapes = [
|
||||
scratch,
|
||||
pltpu.SemaphoreType.DMA,
|
||||
]
|
||||
|
||||
kernel = pl.pallas_call(
|
||||
_kv_cache_update_kernel,
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=len(scalar_prefetches),
|
||||
in_specs=in_specs,
|
||||
out_specs=out_specs,
|
||||
grid=(cdiv(num_kv_update_slices[0], num_slices_per_block),),
|
||||
scratch_shapes=scratch_shapes,
|
||||
),
|
||||
out_shape=out_shape,
|
||||
input_output_aliases={len(scalar_prefetches) + 1: 0},
|
||||
)
|
||||
|
||||
return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
|
||||
862
vllm/v1/attention/ops/prefix_prefill.py
Normal file
862
vllm/v1/attention/ops/prefix_prefill.py
Normal file
@@ -0,0 +1,862 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
|
||||
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
# Static kernels parameters
|
||||
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
|
||||
NUM_WARPS = 4 if current_platform.is_rocm() else 8
|
||||
|
||||
# To check compatibility
|
||||
IS_TURING = current_platform.get_device_capability() == (7, 5)
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
# Here's an example autotuner config for this kernel. This config does provide
|
||||
# a performance improvement, but dramatically increases first call latency in
|
||||
# triton 3.2. Because of this tradeoff, it's currently commented out.
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
|
||||
# "num_unroll_cache": 4, \
|
||||
# "num_unroll_request": 1 } | \
|
||||
# ({"kpack": 2, "waves_per_eu": 2} \
|
||||
# if current_platform.is_rocm() else {}), \
|
||||
# num_warps=4, \
|
||||
# num_stages=1)
|
||||
# ],
|
||||
# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
|
||||
# )
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
sink_ptr,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
out_scale_inv,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
x: tl.constexpr,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl: tl.constexpr,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
num_queries_per_kv: tl.constexpr,
|
||||
IN_PRECISION: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DMODEL_PADDED: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
PHYSICAL_BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SLIDING_WINDOW: tl.constexpr,
|
||||
num_unroll_cache: tl.constexpr,
|
||||
num_unroll_request: tl.constexpr,
|
||||
SKIP_DECODE: tl.constexpr,
|
||||
USE_SINKS: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
MAX_Q_LEN: tl.constexpr = 0,
|
||||
MAX_CTX_LEN: tl.constexpr = 0,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // num_queries_per_kv
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
|
||||
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
|
||||
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
||||
|
||||
if SKIP_DECODE and cur_batch_query_len == 1:
|
||||
return
|
||||
|
||||
# start position inside of the query
|
||||
# generally, N goes over kv, while M goes over query_len
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
# [BLOCK_SIZE]; starts at 0
|
||||
offs_bs_n = tl.arange(0, BLOCK_SIZE)
|
||||
# [N]; starts at 0
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
# [D]; starts at 0
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
|
||||
# [M]; starts at current position in query
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# [M,D]
|
||||
off_q = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :] * stride_qd
|
||||
)
|
||||
|
||||
dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(
|
||||
tl.int1
|
||||
) # [D]
|
||||
|
||||
q = tl.load(
|
||||
Q + off_q,
|
||||
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len),
|
||||
other=0.0,
|
||||
) # [M,D]
|
||||
|
||||
# initialize pointer to m and l
|
||||
if not USE_SINKS:
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
else:
|
||||
m_i = tl.load(
|
||||
sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
|
||||
mask=(offs_m < cur_batch_query_len),
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
l_i = tl.where(m_i > float("-inf"), 1.0, 0.0)
|
||||
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
|
||||
|
||||
# compute query against context (no causal mask here)
|
||||
for start_n in tl.range(
|
||||
0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache
|
||||
):
|
||||
# Under a block size of 544 (Qwen/Qwen3-Next-80B-A3B-Thinking),
|
||||
# replace one physical block every 17 32-Tile blocks
|
||||
# Calculate the logical block index of each of the 32 tokens
|
||||
# in the current Tile (handling cross-block cases).
|
||||
token_indices = start_n + offs_bs_n
|
||||
bn_logical_indices = token_indices // PHYSICAL_BLOCK_SIZE
|
||||
|
||||
# 2. Vectorized loading of physical block IDs from B_Loc
|
||||
bn = tl.load(
|
||||
B_Loc + cur_batch * stride_b_loc_b + bn_logical_indices * stride_b_loc_s
|
||||
).to(tl.int64)
|
||||
|
||||
# 3. Calculate the exact offset of
|
||||
# each token within its physical block.
|
||||
internal_offsets = token_indices % PHYSICAL_BLOCK_SIZE
|
||||
|
||||
# Addressing of K (5D)
|
||||
off_k = (
|
||||
bn[None, :] * stride_k_cache_bs
|
||||
+ cur_kv_head * stride_k_cache_h
|
||||
+ (offs_d[:, None] // x) * stride_k_cache_d
|
||||
+ internal_offsets[None, :] * stride_k_cache_bl
|
||||
+ (offs_d[:, None] % x) * stride_k_cache_x
|
||||
)
|
||||
|
||||
# Addressing of V (4D)
|
||||
off_v = (
|
||||
bn[:, None] * stride_v_cache_bs
|
||||
+ cur_kv_head * stride_v_cache_h
|
||||
+ offs_d[None, :] * stride_v_cache_d
|
||||
+ internal_offsets[:, None] * stride_v_cache_bl
|
||||
)
|
||||
|
||||
if (
|
||||
start_n + BLOCK_SIZE > cur_batch_ctx_len
|
||||
or BLOCK_DMODEL != BLOCK_DMODEL_PADDED
|
||||
):
|
||||
k_load = tl.load(
|
||||
K_cache + off_k,
|
||||
mask=dim_mask[:, None]
|
||||
& ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
|
||||
other=0.0,
|
||||
) # [D,N]
|
||||
else:
|
||||
k_load = tl.load(K_cache + off_k)
|
||||
|
||||
if k_load.dtype.is_fp8():
|
||||
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
|
||||
else:
|
||||
k = k_load
|
||||
|
||||
# qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
|
||||
qk = sm_scale * tl.dot(q, k, input_precision=IN_PRECISION)
|
||||
qk = tl.where(
|
||||
(start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
|
||||
)
|
||||
# qk *= sm_scale
|
||||
if SLIDING_WINDOW > 0:
|
||||
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
|
||||
# Q entries in sequence
|
||||
# (start_n + offs_bs_n[None, :]) are the positions of
|
||||
# KV entries in sequence
|
||||
# So the condition makes sure each entry in Q only attends
|
||||
# to KV entries not more than SLIDING_WINDOW away.
|
||||
#
|
||||
# We can't use -inf here, because the
|
||||
# sliding window may lead to the entire row being masked.
|
||||
# This then makes m_ij contain -inf, which causes NaNs in
|
||||
# exp().
|
||||
qk = tl.where(
|
||||
(cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :])
|
||||
< SLIDING_WINDOW,
|
||||
qk,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
# compute running maximum
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
p = tl.where(m_ij[:, None] == float("-inf"), 0.0, p)
|
||||
l_ij = tl.sum(p, axis=1)
|
||||
alpha = tl.exp(m_i - m_ij)
|
||||
alpha = tl.where(m_i == float("-inf"), 0.0, alpha)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update acc
|
||||
if (
|
||||
start_n + BLOCK_SIZE > cur_batch_ctx_len
|
||||
or BLOCK_DMODEL != BLOCK_DMODEL_PADDED
|
||||
):
|
||||
v_load = tl.load(
|
||||
V_cache + off_v,
|
||||
mask=dim_mask[None, :]
|
||||
& ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
|
||||
other=0.0,
|
||||
) # [N,D]
|
||||
else:
|
||||
v_load = tl.load(V_cache + off_v)
|
||||
|
||||
if v_load.dtype.is_fp8():
|
||||
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
|
||||
else:
|
||||
v = v_load
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
|
||||
# # update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
m_i = m_ij
|
||||
|
||||
off_k = (
|
||||
offs_n[None, :] * stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_d[:, None] * stride_kd
|
||||
)
|
||||
off_v = (
|
||||
offs_n[:, None] * stride_vbs
|
||||
+ cur_kv_head * stride_vh
|
||||
+ offs_d[None, :] * stride_vd
|
||||
)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
# block_mask is 0 when we're already past the current query length
|
||||
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
|
||||
|
||||
# compute query against itself (with causal mask)
|
||||
for start_n in tl.range(
|
||||
0,
|
||||
block_mask * (start_m + 1) * BLOCK_M,
|
||||
BLOCK_N,
|
||||
loop_unroll_factor=num_unroll_request,
|
||||
):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=dim_mask[:, None]
|
||||
& ((start_n + offs_n[None, :]) < cur_batch_query_len),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
|
||||
qk *= sm_scale
|
||||
# apply causal mask
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
if SLIDING_WINDOW > 0:
|
||||
qk = tl.where(
|
||||
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
|
||||
qk,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
# compute running maximum
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
p = tl.where(m_ij[:, None] == float("-inf"), 0.0, p)
|
||||
l_ij = tl.sum(p, axis=1)
|
||||
alpha = tl.exp(m_i - m_ij)
|
||||
# To prevent NaN from appearing in the first round
|
||||
alpha = tl.where(m_i == float("-inf"), 0.0, alpha)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=dim_mask[None, :]
|
||||
& ((start_n + offs_n[:, None]) < cur_batch_query_len),
|
||||
other=0.0,
|
||||
)
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
|
||||
# update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / (l_i[:, None] + 1e-10)
|
||||
|
||||
# initialize pointers to output
|
||||
off_o = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
|
||||
+ cur_head * stride_oh
|
||||
+ offs_d[None, :] * stride_od
|
||||
)
|
||||
out_ptrs = Out + off_o
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale_inv)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
tl.store(
|
||||
out_ptrs, acc, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_alibi(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
Alibi_slopes,
|
||||
block_size,
|
||||
x,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
num_queries_per_kv: int,
|
||||
IN_PRECISION: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr, # head size
|
||||
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
|
||||
BLOCK_N: tl.constexpr,
|
||||
SKIP_DECODE: tl.constexpr,
|
||||
):
|
||||
# attn_bias[]
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // num_queries_per_kv
|
||||
|
||||
# cur_batch_seq_len: the length of prompts
|
||||
# cur_batch_ctx_len: the length of prefix
|
||||
# cur_batch_in_all_start_index: the start id of the dim=0
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
|
||||
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
|
||||
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
||||
|
||||
if SKIP_DECODE and cur_batch_query_len == 1:
|
||||
return
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :] * stride_qd
|
||||
)
|
||||
|
||||
dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(
|
||||
tl.int1
|
||||
)
|
||||
|
||||
q = tl.load(
|
||||
Q + off_q,
|
||||
mask=dim_mask[None, :]
|
||||
& (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
# # initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
|
||||
|
||||
alibi_slope = tl.load(Alibi_slopes + cur_head)
|
||||
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
|
||||
alibi_start_k = 0
|
||||
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
bn = tl.load(
|
||||
B_Loc
|
||||
+ cur_batch * stride_b_loc_b
|
||||
+ ((start_n + offs_n) // block_size) * stride_b_loc_s,
|
||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
||||
other=0,
|
||||
).to(tl.int64)
|
||||
off_k = (
|
||||
bn[None, :] * stride_k_cache_bs
|
||||
+ cur_kv_head * stride_k_cache_h
|
||||
+ (offs_d[:, None] // x) * stride_k_cache_d
|
||||
+ ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl
|
||||
+ (offs_d[:, None] % x) * stride_k_cache_x
|
||||
)
|
||||
off_v = (
|
||||
bn[:, None] * stride_v_cache_bs
|
||||
+ cur_kv_head * stride_v_cache_h
|
||||
+ offs_d[None, :] * stride_v_cache_d
|
||||
+ (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl
|
||||
)
|
||||
k_load = tl.load(
|
||||
K_cache + off_k,
|
||||
mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
|
||||
other=0.0,
|
||||
) # [D,N]
|
||||
|
||||
if k_load.dtype.is_fp8():
|
||||
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
|
||||
else:
|
||||
k = k_load
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
|
||||
qk = tl.where(
|
||||
(start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
|
||||
)
|
||||
qk *= sm_scale
|
||||
|
||||
# load alibi
|
||||
alibi = (
|
||||
tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None]
|
||||
) * alibi_slope
|
||||
alibi = tl.where(
|
||||
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
|
||||
alibi,
|
||||
float("-inf"),
|
||||
)
|
||||
qk += alibi
|
||||
alibi_start_k += BLOCK_N
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v_load = tl.load(
|
||||
V_cache + off_v,
|
||||
mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
|
||||
other=0.0,
|
||||
)
|
||||
if v_load.dtype.is_fp8():
|
||||
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
|
||||
else:
|
||||
v = v_load
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision="ieee")
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
off_k = (
|
||||
offs_n[None, :] * stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_d[:, None] * stride_kd
|
||||
)
|
||||
off_v = (
|
||||
offs_n[:, None] * stride_vbs
|
||||
+ cur_kv_head * stride_vh
|
||||
+ offs_d[None, :] * stride_vd
|
||||
)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
|
||||
|
||||
# init alibi
|
||||
alibi_slope = tl.load(Alibi_slopes + cur_head)
|
||||
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
|
||||
alibi_start_k = cur_batch_ctx_len
|
||||
# # init debugger
|
||||
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
|
||||
# offset_db_k = tl.arange(0, BLOCK_N)
|
||||
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=dim_mask[:, None]
|
||||
& ((start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.dot(q, k, acc=qk, input_precision="ieee")
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
|
||||
# load alibi
|
||||
alibi = (
|
||||
tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None]
|
||||
) * alibi_slope
|
||||
alibi = tl.where(
|
||||
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
|
||||
alibi,
|
||||
float("-inf"),
|
||||
)
|
||||
qk += alibi
|
||||
alibi_start_k += BLOCK_N
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=dim_mask[None, :]
|
||||
& ((start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0,
|
||||
)
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision="ieee")
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
# initialize pointers to output
|
||||
off_o = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
|
||||
+ cur_head * stride_oh
|
||||
+ offs_d[None, :] * stride_od
|
||||
)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(
|
||||
out_ptrs,
|
||||
acc,
|
||||
mask=dim_mask[None, :]
|
||||
& (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def context_attention_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
kv_cache_dtype: str,
|
||||
k_cache,
|
||||
v_cache,
|
||||
b_loc,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_seq_len,
|
||||
max_input_len,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
sm_scale=None,
|
||||
skip_decode=False,
|
||||
fp8_out_scale=None,
|
||||
sinks=None,
|
||||
is_block_table_ptr: bool = False,
|
||||
):
|
||||
q_dtype_is_f32 = q.dtype is torch.float32
|
||||
|
||||
# Turing does have tensor core for float32 multiplication
|
||||
# use ieee as fallback for triton kernels work. There is also
|
||||
# warning on vllm/config.py to inform users this fallback
|
||||
# implementation
|
||||
IN_PRECISION = "ieee" if IS_TURING and q_dtype_is_f32 else None
|
||||
|
||||
# Conversion of FP8 Tensor from uint8 storage to
|
||||
# appropriate torch.dtype for interpretation by Triton
|
||||
if "fp8" in kv_cache_dtype:
|
||||
assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
|
||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||
target_dtype = current_platform.fp8_dtype()
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
target_dtype = torch.float8_e5m2
|
||||
else:
|
||||
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
|
||||
|
||||
k_cache = k_cache.view(target_dtype)
|
||||
v_cache = v_cache.view(target_dtype)
|
||||
|
||||
if (
|
||||
k_cache.dtype == torch.uint8
|
||||
or v_cache.dtype == torch.uint8
|
||||
and kv_cache_dtype == "auto"
|
||||
):
|
||||
raise ValueError(
|
||||
"kv_cache_dtype='auto' unsupported for\
|
||||
FP8 KV Cache prefill kernel"
|
||||
)
|
||||
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
# round up Lk to a power of 2 - this is required for Triton block size
|
||||
Lk_padded = triton.next_power_of_2(Lk)
|
||||
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
num_queries_per_kv = q.shape[1] // k.shape[1]
|
||||
|
||||
assert batch + 1 == len(b_start_loc)
|
||||
|
||||
# 0 means "disable"
|
||||
if sliding_window is None or sliding_window <= 0:
|
||||
sliding_window = 0
|
||||
|
||||
if is_block_table_ptr:
|
||||
kv_element_size = k_cache.element_size()
|
||||
block_byte_stride = k_cache.stride(0) * kv_element_size
|
||||
# The physical starting point of the obtained KV Cache Pool
|
||||
base_addr = k_cache.data_ptr()
|
||||
|
||||
mask = b_loc > 0
|
||||
processed_b_loc = torch.where(
|
||||
mask, (b_loc - base_addr) // block_byte_stride, b_loc
|
||||
).to(torch.int32)
|
||||
else:
|
||||
processed_b_loc = b_loc.to(torch.int32)
|
||||
|
||||
if alibi_slopes is not None:
|
||||
assert sinks is None, "Sinks arg is not supported with alibi"
|
||||
assert fp8_out_scale is None, "FP8 output not supported with alibi"
|
||||
# need to reduce num. blocks when using fp32
|
||||
# due to increased use of GPU shared memory
|
||||
# if q.dtype is torch.float32:
|
||||
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
|
||||
# batch, head,
|
||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
||||
_fwd_kernel_alibi[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
b_loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
alibi_slopes,
|
||||
v_cache.shape[3],
|
||||
k_cache.shape[4],
|
||||
o,
|
||||
b_loc.stride(0),
|
||||
b_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
k_cache.stride(4), # [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(3), # [num_blocks, num_kv_heads, head_size, block_size]
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
IN_PRECISION=IN_PRECISION,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_DMODEL_PADDED=Lk_padded,
|
||||
BLOCK_N=BLOCK,
|
||||
SKIP_DECODE=skip_decode,
|
||||
num_warps=NUM_WARPS,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
max_seq_len = 0 if max_seq_len is None else max_seq_len
|
||||
extra_kargs = {}
|
||||
if current_platform.is_rocm():
|
||||
extra_kargs = {}
|
||||
|
||||
real_block_size = v_cache.shape[3]
|
||||
is_pow2 = real_block_size > 0 and (real_block_size & (real_block_size - 1) == 0)
|
||||
# For standard models involving powers of 2,
|
||||
# follow the original logic (Llama 128/64)
|
||||
# For non-standard models (Qwen3-next block_size 544), set to 32.
|
||||
if is_pow2:
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64
|
||||
else:
|
||||
BLOCK_M = 32
|
||||
BLOCK_N = 32
|
||||
|
||||
# TRITON_BLOCK_SIZE is kept at 32 to ensure
|
||||
# correct alignment logic when the kernel handles
|
||||
# non-standard sizes (such as 544).
|
||||
TRITON_BLOCK_SIZE = 32
|
||||
|
||||
grid_fn = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
|
||||
_fwd_kernel[grid_fn](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
sinks,
|
||||
processed_b_loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
k_cache.shape[4],
|
||||
o,
|
||||
processed_b_loc.stride(0),
|
||||
processed_b_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
stride_k_cache_bs=k_cache.stride(0),
|
||||
stride_k_cache_h=k_cache.stride(1),
|
||||
stride_k_cache_d=k_cache.stride(2),
|
||||
stride_k_cache_bl=k_cache.stride(3),
|
||||
stride_k_cache_x=k_cache.stride(4),
|
||||
stride_v_cache_bs=v_cache.stride(0),
|
||||
stride_v_cache_h=v_cache.stride(1),
|
||||
stride_v_cache_d=v_cache.stride(2),
|
||||
stride_v_cache_bl=v_cache.stride(3),
|
||||
BLOCK_SIZE=TRITON_BLOCK_SIZE,
|
||||
PHYSICAL_BLOCK_SIZE=real_block_size,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
IN_PRECISION=IN_PRECISION,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_DMODEL_PADDED=Lk_padded,
|
||||
SLIDING_WINDOW=sliding_window,
|
||||
SKIP_DECODE=skip_decode,
|
||||
USE_FP8=fp8_out_scale is not None,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
num_unroll_cache=4,
|
||||
num_unroll_request=1,
|
||||
num_warps=4,
|
||||
num_stages=1,
|
||||
USE_SINKS=sinks is not None,
|
||||
**extra_kargs,
|
||||
)
|
||||
return
|
||||
210
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Normal file
210
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
|
||||
def fp8_mqa_logits_torch(
|
||||
q: torch.Tensor,
|
||||
kv: tuple[torch.Tensor, torch.Tensor],
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits for a single sequence without KV paging.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape [M, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
|
||||
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
|
||||
[N, 1]) with dtype `torch.float32`.
|
||||
weights: weights of shape [M, H], dtype `torch.float32`.
|
||||
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [M, N], dtype `torch.float32`.
|
||||
"""
|
||||
k_fp8, scale = kv
|
||||
seq_len_kv = k_fp8.shape[0]
|
||||
k = k_fp8.to(torch.bfloat16)
|
||||
q = q.to(torch.bfloat16)
|
||||
|
||||
mask_lo = (
|
||||
torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
|
||||
)
|
||||
mask_hi = (
|
||||
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
|
||||
)
|
||||
mask = mask_lo & mask_hi
|
||||
|
||||
score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
|
||||
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def rocm_fp8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv: tuple[torch.Tensor, torch.Tensor],
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits for a single sequence without KV paging.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape [M, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
|
||||
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
|
||||
[N, 1]) with dtype `torch.float32`.
|
||||
weights: weights of shape [M, H], dtype `torch.float32`.
|
||||
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [M, N], dtype `torch.float32`.
|
||||
"""
|
||||
|
||||
# TODO(ganyi): Temporarily workaround, will remove the module check and reference
|
||||
# path after aiter merge this kernel into main
|
||||
@lru_cache
|
||||
def has_mqa_logits_module():
|
||||
return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None
|
||||
|
||||
if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
|
||||
from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits
|
||||
|
||||
kv, scale = kv
|
||||
return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
|
||||
else:
|
||||
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
|
||||
|
||||
|
||||
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
|
||||
def fp8_paged_mqa_logits_torch(
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
max_model_len: int,
|
||||
):
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
batch_size, next_n, _, dim = q.size()
|
||||
kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
|
||||
scale = scale.contiguous().view(torch.float)
|
||||
q = q.float()
|
||||
kv_cache = kv_cache.view(fp8_dtype).float() * scale
|
||||
num_block, block_size, _, dim = kv_cache.size()
|
||||
logits = torch.full(
|
||||
[batch_size * next_n, max_model_len],
|
||||
float("-inf"),
|
||||
device=q.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
context_lens = context_lens.tolist()
|
||||
for i in range(batch_size):
|
||||
context_len = context_lens[i]
|
||||
q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
|
||||
weight_slice = (
|
||||
weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
|
||||
)
|
||||
for block_rk in range(cdiv(context_len, block_size)):
|
||||
block_idx = block_tables[i][block_rk]
|
||||
qx, kx = q[i], kv_cache[block_idx]
|
||||
k_offsets = torch.arange(
|
||||
block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
|
||||
)
|
||||
mask = (k_offsets[None, :] < context_len) & (
|
||||
k_offsets[None, :] <= q_offsets[:, None]
|
||||
)
|
||||
s = torch.where(
|
||||
mask[None, :, :],
|
||||
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
|
||||
logits.dtype
|
||||
),
|
||||
float("-inf"),
|
||||
)
|
||||
s = torch.relu(s) * weight_slice[..., None]
|
||||
s = s.sum(dim=0)
|
||||
logits[
|
||||
i * next_n : (i + 1) * next_n,
|
||||
block_rk * block_size : (block_rk + 1) * block_size,
|
||||
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
def rocm_fp8_paged_mqa_logits(
|
||||
q_fp8: torch.Tensor,
|
||||
kv_cache_fp8: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
schedule_metadata: torch.Tensor,
|
||||
max_model_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits using paged KV-cache.
|
||||
|
||||
Args:
|
||||
q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
|
||||
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
|
||||
4 bytes per (block,pos) store the `float` dequant scale.
|
||||
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
|
||||
context_lens: Tensor of shape [B], dtype int32; effective context length
|
||||
for each batch element.
|
||||
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
|
||||
block indices to physical blocks in the paged cache.
|
||||
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
|
||||
used to distribute work across SMs.
|
||||
max_model_len: Maximum sequence length used to size the logits output.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [B * next_n, max_model_len], dtype
|
||||
`torch.float32`.
|
||||
"""
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1
|
||||
|
||||
batch_size, next_n, heads, _ = q_fp8.shape
|
||||
out_qk = torch.full(
|
||||
(heads, batch_size * next_n, max_model_len),
|
||||
float("-inf"),
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
)
|
||||
deepgemm_fp8_paged_mqa_logits_stage1(
|
||||
q_fp8,
|
||||
kv_cache_fp8,
|
||||
weights,
|
||||
out_qk,
|
||||
context_lens,
|
||||
block_tables,
|
||||
max_model_len,
|
||||
)
|
||||
return out_qk.sum(dim=0)
|
||||
else:
|
||||
return fp8_paged_mqa_logits_torch(
|
||||
q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
|
||||
)
|
||||
709
vllm/v1/attention/ops/triton_decode_attention.py
Normal file
709
vllm/v1/attention/ops/triton_decode_attention.py
Normal file
@@ -0,0 +1,709 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
|
||||
# which was originally adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
|
||||
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
|
||||
|
||||
# Changes:
|
||||
# - Add support for page size >= 1.
|
||||
|
||||
# Copyright 2025 vLLM Team
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Memory-efficient attention for decoding.
|
||||
It supports page size >= 1.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from packaging import version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
is_hip_ = current_platform.is_rocm()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Only print the following warnings when triton version < 3.2.0.
|
||||
# The issue won't affect performance or accuracy.
|
||||
if version.parse(triton.__version__) < version.parse("3.2.0"):
|
||||
logger.warning(
|
||||
"The following error message 'operation scheduled before its operands' "
|
||||
"can be ignored."
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def tanh(x):
|
||||
# Tanh is just a scaled sigmoid
|
||||
return 2 * tl.sigmoid(2 * x) - 1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_stage1(
|
||||
Q,
|
||||
K_Buffer,
|
||||
V_Buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
Att_Out,
|
||||
stride_req_to_tokens_b,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_buf_kbs,
|
||||
stride_buf_kh,
|
||||
stride_buf_vbs,
|
||||
stride_buf_vh,
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
kv_group_num: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
NUM_KV_SPLITS: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr,
|
||||
logit_cap: tl.constexpr,
|
||||
Lk: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
split_kv_id = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_dv = tl.arange(0, BLOCK_DV)
|
||||
mask_d = offs_d < Lk
|
||||
mask_dv = offs_dv < Lv
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_req_idx = cur_batch
|
||||
|
||||
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
||||
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
|
||||
|
||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
||||
split_kv_start = kv_len_per_split * split_kv_id
|
||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
||||
|
||||
e_max = -float("inf")
|
||||
e_sum = 0.0
|
||||
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||
kv_page_number = tl.load(
|
||||
Req_to_tokens
|
||||
+ stride_req_to_tokens_b * cur_batch_req_idx
|
||||
+ offs_n // PAGE_SIZE,
|
||||
mask=offs_n < split_kv_end,
|
||||
other=0,
|
||||
)
|
||||
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
|
||||
offs_buf_k = (
|
||||
kv_loc[:, None] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_d[None, :]
|
||||
)
|
||||
k = tl.load(
|
||||
K_Buffer + offs_buf_k,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
qk = tl.sum(q[None, :] * k, 1)
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
|
||||
|
||||
offs_buf_v = (
|
||||
kv_loc[:, None] * stride_buf_vbs
|
||||
+ cur_kv_head * stride_buf_vh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max)
|
||||
acc *= re_scale
|
||||
acc += tl.sum(p[:, None] * v, 0)
|
||||
|
||||
e_sum = e_sum * re_scale + tl.sum(p, 0)
|
||||
e_max = n_e_max
|
||||
|
||||
offs_mid_o = (
|
||||
cur_batch * stride_mid_ob
|
||||
+ cur_head * stride_mid_oh
|
||||
+ split_kv_id * stride_mid_os
|
||||
+ offs_dv
|
||||
)
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o,
|
||||
acc / e_sum,
|
||||
mask=(mask_dv),
|
||||
)
|
||||
|
||||
offs_mid_o_1 = (
|
||||
cur_batch * stride_mid_ob
|
||||
+ cur_head * stride_mid_oh
|
||||
+ split_kv_id * stride_mid_os
|
||||
+ Lv
|
||||
)
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o_1,
|
||||
e_max + tl.log(e_sum),
|
||||
)
|
||||
|
||||
|
||||
def _decode_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
att_out,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
):
|
||||
BLOCK = 64 if not is_hip_ else 8
|
||||
|
||||
NUM_KV_SPLITS = num_kv_splits
|
||||
Lk = k_buffer.shape[-1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
|
||||
batch, head_num = q.shape[0], q.shape[1]
|
||||
|
||||
grid = (batch, head_num, NUM_KV_SPLITS)
|
||||
kv_group_num = q.shape[1] // k_buffer.shape[-2]
|
||||
|
||||
num_warps = 4
|
||||
if kv_group_num != 1:
|
||||
num_warps = 1 if is_hip_ else 2
|
||||
|
||||
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
_fwd_kernel_stage1[grid](
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
att_out,
|
||||
Req_to_tokens.stride(0),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
att_out.stride(0),
|
||||
att_out.stride(1),
|
||||
att_out.stride(2),
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
BLOCK_N=BLOCK,
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
PAGE_SIZE=page_size,
|
||||
logit_cap=logit_cap,
|
||||
num_warps=num_warps,
|
||||
num_stages=2,
|
||||
Lk=Lk,
|
||||
Lv=Lv,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_grouped_kernel_stage1(
|
||||
Q,
|
||||
K_Buffer,
|
||||
V_Buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
Att_Out,
|
||||
stride_req_to_tokens_b,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_buf_kbs,
|
||||
stride_buf_kh,
|
||||
stride_buf_vbs,
|
||||
stride_buf_vh,
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
kv_group_num: tl.constexpr,
|
||||
q_head_num: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DPE: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_H: tl.constexpr,
|
||||
NUM_KV_SPLITS: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr,
|
||||
logit_cap: tl.constexpr,
|
||||
Lk: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head_id = tl.program_id(1)
|
||||
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
|
||||
split_kv_id = tl.program_id(2)
|
||||
|
||||
VALID_BLOCK_H: tl.constexpr = BLOCK_H if kv_group_num > BLOCK_H else kv_group_num
|
||||
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
|
||||
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
|
||||
mask_h = mask_h & (cur_head < q_head_num)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_dv = tl.arange(0, BLOCK_DV)
|
||||
mask_d = offs_d < Lk
|
||||
mask_dv = offs_dv < Lv
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_req_idx = cur_batch
|
||||
|
||||
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
|
||||
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
|
||||
|
||||
if BLOCK_DPE > 0:
|
||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||
mask_dpe = offs_dpe < Lk
|
||||
off_qpe = (
|
||||
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
|
||||
)
|
||||
qpe = tl.load(
|
||||
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
|
||||
)
|
||||
|
||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
||||
split_kv_start = kv_len_per_split * split_kv_id
|
||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
||||
|
||||
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
|
||||
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||
kv_page_number = tl.load(
|
||||
Req_to_tokens
|
||||
+ stride_req_to_tokens_b * cur_batch_req_idx
|
||||
+ offs_n // PAGE_SIZE,
|
||||
mask=offs_n < split_kv_end,
|
||||
other=0,
|
||||
)
|
||||
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
|
||||
offs_buf_k = (
|
||||
kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
k = tl.load(
|
||||
K_Buffer + offs_buf_k,
|
||||
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
|
||||
other=0.0,
|
||||
)
|
||||
qk = tl.dot(q, k.to(q.dtype))
|
||||
if BLOCK_DPE > 0:
|
||||
offs_buf_kpe = (
|
||||
kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
kpe = tl.load(
|
||||
K_Buffer + offs_buf_kpe,
|
||||
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe, kpe.to(qpe.dtype))
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
qk = tl.where(
|
||||
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
|
||||
)
|
||||
|
||||
offs_buf_v = (
|
||||
kv_loc[:, None] * stride_buf_vbs
|
||||
+ cur_kv_head * stride_buf_vh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
acc *= re_scale[:, None]
|
||||
acc += tl.dot(p.to(v.dtype), v)
|
||||
|
||||
e_sum = e_sum * re_scale + tl.sum(p, 1)
|
||||
e_max = n_e_max
|
||||
|
||||
offs_mid_o = (
|
||||
cur_batch * stride_mid_ob
|
||||
+ cur_head[:, None] * stride_mid_oh
|
||||
+ split_kv_id * stride_mid_os
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o,
|
||||
acc / e_sum[:, None],
|
||||
mask=(mask_h[:, None]) & (mask_dv[None, :]),
|
||||
)
|
||||
|
||||
offs_mid_o_1 = (
|
||||
cur_batch * stride_mid_ob
|
||||
+ cur_head * stride_mid_oh
|
||||
+ split_kv_id * stride_mid_os
|
||||
+ Lv
|
||||
)
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o_1,
|
||||
e_max + tl.log(e_sum),
|
||||
mask=mask_h,
|
||||
)
|
||||
|
||||
|
||||
def _decode_grouped_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
att_out,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
):
|
||||
BLOCK = 32
|
||||
Lk = k_buffer.shape[-1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
|
||||
# [TODO] work around shmem limit on MI3xx
|
||||
if is_hip_ and Lk >= 576:
|
||||
BLOCK = 16
|
||||
|
||||
if Lk == 576:
|
||||
BLOCK_DMODEL = 512
|
||||
BLOCK_DPE = 64
|
||||
elif Lk == 288:
|
||||
BLOCK_DMODEL = 256
|
||||
BLOCK_DPE = 32
|
||||
else:
|
||||
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
||||
BLOCK_DPE = 0
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
batch, head_num = q.shape[0], q.shape[1]
|
||||
kv_group_num = q.shape[1] // k_buffer.shape[-2]
|
||||
|
||||
BLOCK_H = 16
|
||||
NUM_KV_SPLITS = num_kv_splits
|
||||
grid = (
|
||||
batch,
|
||||
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
|
||||
NUM_KV_SPLITS,
|
||||
)
|
||||
|
||||
extra_kargs = {}
|
||||
num_stages = 2
|
||||
if is_hip_:
|
||||
# https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
|
||||
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||
num_stages = 1
|
||||
|
||||
_fwd_grouped_kernel_stage1[grid](
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
att_out,
|
||||
Req_to_tokens.stride(0),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
att_out.stride(0),
|
||||
att_out.stride(1),
|
||||
att_out.stride(2),
|
||||
kv_group_num=kv_group_num,
|
||||
q_head_num=head_num,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DPE=BLOCK_DPE,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
BLOCK_N=BLOCK,
|
||||
BLOCK_H=BLOCK_H,
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
PAGE_SIZE=page_size,
|
||||
logit_cap=logit_cap,
|
||||
num_warps=4,
|
||||
num_stages=num_stages,
|
||||
Lk=Lk,
|
||||
Lv=Lv,
|
||||
**extra_kargs,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_stage2(
|
||||
Mid_O,
|
||||
o,
|
||||
lse,
|
||||
B_Seqlen,
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_lse_bs,
|
||||
NUM_KV_SPLITS: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DV)
|
||||
mask_d = offs_d < Lv
|
||||
|
||||
e_sum = 0.0
|
||||
e_max = -float("inf")
|
||||
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
|
||||
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
|
||||
|
||||
for split_kv_id in range(0, NUM_KV_SPLITS):
|
||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
||||
split_kv_start = kv_len_per_split * split_kv_id
|
||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
tv = tl.load(
|
||||
Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
|
||||
)
|
||||
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
|
||||
n_e_max = tl.maximum(tlogic, e_max)
|
||||
|
||||
old_scale = tl.exp(e_max - n_e_max)
|
||||
acc *= old_scale
|
||||
exp_logic = tl.exp(tlogic - n_e_max)
|
||||
acc += exp_logic * tv
|
||||
|
||||
e_sum = e_sum * old_scale + exp_logic
|
||||
e_max = n_e_max
|
||||
|
||||
tl.store(
|
||||
o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
|
||||
acc / e_sum,
|
||||
mask=mask_d,
|
||||
)
|
||||
lse_val = e_max + tl.log(e_sum)
|
||||
tl.store(
|
||||
lse + cur_batch * stride_lse_bs + cur_head,
|
||||
lse_val,
|
||||
)
|
||||
|
||||
|
||||
def _decode_softmax_reducev_fwd(
|
||||
logits,
|
||||
q,
|
||||
o,
|
||||
lse,
|
||||
v_buffer,
|
||||
b_seq_len,
|
||||
num_kv_splits,
|
||||
):
|
||||
batch, head_num = q.shape[0], q.shape[1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
NUM_KV_SPLITS = num_kv_splits
|
||||
|
||||
extra_kargs = {}
|
||||
if is_hip_:
|
||||
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||
|
||||
grid = (batch, head_num)
|
||||
_fwd_kernel_stage2[grid](
|
||||
logits,
|
||||
o,
|
||||
lse,
|
||||
b_seq_len,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
logits.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
lse.stride(0),
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
Lv=Lv,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
**extra_kargs,
|
||||
)
|
||||
|
||||
|
||||
def decode_attention_fwd_normal(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
_decode_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(
|
||||
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
|
||||
)
|
||||
|
||||
|
||||
def decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
_decode_grouped_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(
|
||||
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
|
||||
)
|
||||
|
||||
|
||||
def decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size=1,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
assert num_kv_splits == attn_logits.shape[2]
|
||||
kv_group_num = q.shape[1] // v_buffer.shape[-2]
|
||||
|
||||
if kv_group_num == 1:
|
||||
# MHA
|
||||
decode_attention_fwd_normal(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
else:
|
||||
# GQA/MQA/MLA
|
||||
decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
116
vllm/v1/attention/ops/triton_merge_attn_states.py
Normal file
116
vllm/v1/attention/ops/triton_merge_attn_states.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
# can be used to combine partial attention results (in the split-KV case)
|
||||
def merge_attn_states(
|
||||
output: torch.Tensor,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
num_tokens = output.shape[0]
|
||||
num_query_heads = output.shape[1]
|
||||
head_size = output.shape[2]
|
||||
padded_head_size = triton.next_power_of_2(head_size)
|
||||
# We assume the output stride on num_head is not always as same as the
|
||||
# `suffix_output` and `prefix_output`, as them might be padded by the attention
|
||||
# backend.
|
||||
prefix_head_stride = prefix_output.stride(1)
|
||||
output_head_stride = output.stride(1)
|
||||
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
|
||||
merge_attn_states_kernel[(num_tokens, num_query_heads)](
|
||||
output,
|
||||
output_lse,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
prefix_head_stride,
|
||||
output_head_stride,
|
||||
head_size,
|
||||
padded_head_size,
|
||||
output_lse is not None,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def merge_attn_states_kernel(
|
||||
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
output_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefix_head_stride,
|
||||
output_head_stride,
|
||||
HEAD_SIZE: tl.constexpr,
|
||||
PADDED_HEAD_SIZE: tl.constexpr,
|
||||
OUTPUT_LSE: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
num_tokens = tl.num_programs(0)
|
||||
head_idx = tl.program_id(1)
|
||||
num_heads = tl.num_programs(1)
|
||||
|
||||
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
|
||||
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
|
||||
|
||||
# FA2 and FA3 have different behavior for when the sum-exp is 0, this namely
|
||||
# arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
|
||||
# If we see an inf assume FA2 and convert inf to -inf for consistency
|
||||
# and correctness. Inf generally doesn't make sense in this context outside
|
||||
# of undefined-behavior/FA2-case, so I think this a safe assumption.
|
||||
p_lse = float("-inf") if p_lse == float("inf") else p_lse
|
||||
s_lse = float("-inf") if s_lse == float("inf") else s_lse
|
||||
|
||||
max_lse = tl.maximum(p_lse, s_lse)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
# Will reuse precomputed Exp values for scale factor computation.
|
||||
p_se = tl.exp(p_lse)
|
||||
s_se = tl.exp(s_lse)
|
||||
out_se = p_se + s_se
|
||||
|
||||
if OUTPUT_LSE:
|
||||
out_lse = tl.log(out_se) + max_lse
|
||||
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)
|
||||
|
||||
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
|
||||
head_mask = head_arange < HEAD_SIZE
|
||||
p_out = tl.load(
|
||||
prefix_output
|
||||
+ token_idx * num_heads * prefix_head_stride
|
||||
+ head_idx * prefix_head_stride
|
||||
+ head_arange,
|
||||
mask=head_mask,
|
||||
)
|
||||
s_out = tl.load(
|
||||
suffix_output
|
||||
+ token_idx * num_heads * prefix_head_stride
|
||||
+ head_idx * prefix_head_stride
|
||||
+ head_arange,
|
||||
mask=head_mask,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Be careful with the numerical stability.
|
||||
# We should compute the scale first, and then multiply it with the output.
|
||||
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
|
||||
p_scale = p_se / out_se
|
||||
s_scale = s_se / out_se
|
||||
out = p_out * p_scale + s_out * s_scale
|
||||
tl.store(
|
||||
output
|
||||
+ token_idx * num_heads * output_head_stride
|
||||
+ head_idx * output_head_stride
|
||||
+ head_arange,
|
||||
out,
|
||||
mask=head_mask,
|
||||
)
|
||||
271
vllm/v1/attention/ops/triton_prefill_attention.py
Normal file
271
vllm/v1/attention/ops/triton_prefill_attention.py
Normal file
@@ -0,0 +1,271 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/sgl-project/sglang/blob/97cb762bb65ebf05025eb342de03c184660427a3/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
|
||||
# Changes:
|
||||
# - Add support for sliding window attention
|
||||
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Memory-efficient attention for prefill.
|
||||
It supports page size = 1.
|
||||
"""
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
Out,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
kv_group_num: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
SLIDING_WINDOW_Q: tl.constexpr,
|
||||
SLIDING_WINDOW_K: tl.constexpr,
|
||||
Lk: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :]
|
||||
)
|
||||
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
|
||||
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
|
||||
|
||||
mask_d = offs_d < Lk
|
||||
|
||||
q = tl.load(
|
||||
Q + off_q,
|
||||
mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
||||
|
||||
# Calculate the end position for attention computation
|
||||
end_n = cur_batch_seq_len
|
||||
|
||||
# Apply causal attention pruning and sliding window attention pruning
|
||||
end_n = tl.minimum(end_n, (start_m + 1) * BLOCK_M) if IS_CAUSAL else end_n
|
||||
|
||||
# Calculate the start position for backward sliding window
|
||||
start_n_limit = 0
|
||||
end_n_limit = block_mask * end_n
|
||||
|
||||
for start_n in range(start_n_limit, end_n_limit, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
# Apply attention mask (causal + bidirectional sliding window)
|
||||
# Position indices in the sequence
|
||||
pos_q = offs_m[:, None] # Query positions [BLOCK_M, 1]
|
||||
pos_k = start_n + offs_n[None, :] # Key positions [1, BLOCK_N]
|
||||
|
||||
# Valid sequence mask
|
||||
mask = pos_k < cur_batch_seq_len
|
||||
# Causal mask
|
||||
if IS_CAUSAL:
|
||||
mask &= pos_q >= pos_k
|
||||
|
||||
# Bidirectional sliding window masks
|
||||
sliding_mask_q = (
|
||||
pos_q - pos_k <= SLIDING_WINDOW_Q if SLIDING_WINDOW_Q > 0 else None
|
||||
)
|
||||
sliding_mask_k = (
|
||||
pos_k - pos_q <= SLIDING_WINDOW_K if SLIDING_WINDOW_K > 0 else None
|
||||
)
|
||||
if sliding_mask_q is not None:
|
||||
mask &= sliding_mask_q
|
||||
if sliding_mask_k is not None:
|
||||
mask &= sliding_mask_k
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.where(mask, 0, float("-inf"))
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
# For sliding window there's a chance the max is -inf due to masking of
|
||||
# the entire row. In this case we need to set m_j 0 to avoid NaN
|
||||
m_ij_valid_mask = m_ij > float("-inf")
|
||||
m_ij_masked = tl.where(m_ij_valid_mask, m_ij, 0.0)
|
||||
# -- compute p and l_ij --
|
||||
p = tl.exp(qk - m_ij_masked[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
m_i_new_mask = m_i_new > float("-inf")
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
# mask alpha and beta for sliding window
|
||||
alpha = tl.where(m_i_new_mask, alpha, 1.0)
|
||||
beta = tl.where(m_i_new_mask, beta, 0.0)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# For sliding window there's a chance the l_i_new is 0 due to masking
|
||||
# the entire row. We need to set l_i_new 1 to avoid zero division
|
||||
l_i_new_mask = (l_i_new != 0.0) & (m_i_new_mask > float("-inf"))
|
||||
l_i_new_safe = tl.where(l_i_new_mask, l_i_new, 1.0)
|
||||
p_scale = beta / l_i_new_safe
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new_safe * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# initialize pointers to output
|
||||
off_o = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
|
||||
+ cur_head * stride_oh
|
||||
+ offs_d[None, :]
|
||||
)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(
|
||||
out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
|
||||
)
|
||||
|
||||
|
||||
def get_block_size(dtype: torch.dtype) -> int:
|
||||
if dtype == torch.float32:
|
||||
return 32
|
||||
elif current_platform.is_cuda_alike() and current_platform.has_device_capability(
|
||||
80
|
||||
):
|
||||
return 128
|
||||
else:
|
||||
return 64
|
||||
|
||||
|
||||
def context_attention_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_input_len,
|
||||
is_causal=True,
|
||||
sliding_window_q=None,
|
||||
sliding_window_k=None,
|
||||
):
|
||||
"""
|
||||
q, k, v: [b * s, head, head_dim]
|
||||
b_start_loc: [b]
|
||||
b_seq_len: [b]
|
||||
out: [b * s, head, head_dim]
|
||||
"""
|
||||
BLOCK = get_block_size(q.dtype)
|
||||
|
||||
Lq, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
kv_group_num = q.shape[1] // k.shape[1]
|
||||
|
||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
sliding_window_q = sliding_window_q if sliding_window_q is not None else 0
|
||||
sliding_window_k = sliding_window_k if sliding_window_k is not None else 0
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
o,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=triton.next_power_of_2(Lk),
|
||||
BLOCK_N=BLOCK,
|
||||
IS_CAUSAL=is_causal,
|
||||
SLIDING_WINDOW_Q=sliding_window_q,
|
||||
SLIDING_WINDOW_K=sliding_window_k,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
Lk=Lk,
|
||||
)
|
||||
395
vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
Normal file
395
vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
Normal file
@@ -0,0 +1,395 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def reshape_and_cache_kernel_flash(
|
||||
key_ptr, # [num_tokens, num_heads, head_size]
|
||||
value_ptr, # [num_tokens, num_heads, head_size]
|
||||
key_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
|
||||
value_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
|
||||
slot_mapping_ptr, # [num_tokens]
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
# strides
|
||||
key_stride: tl.int64,
|
||||
value_stride: tl.int64,
|
||||
block_stride: tl.int64,
|
||||
head_stride: tl.int64,
|
||||
dim_stride_k: tl.int64,
|
||||
dim_stride_v: tl.int64,
|
||||
page_stride: tl.int64,
|
||||
num_heads: tl.constexpr,
|
||||
head_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
x: tl.constexpr,
|
||||
USE_HEAD_MAJOR_LAYOUT: tl.constexpr,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE: tl.constexpr,
|
||||
# tune parameters
|
||||
TILE_SIZE: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(axis=0)
|
||||
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
|
||||
if slot_idx < 0:
|
||||
# Padding token that should be ignored.
|
||||
return
|
||||
|
||||
block_idx = slot_idx // block_size
|
||||
block_offset = slot_idx % block_size
|
||||
|
||||
tile_i = tl.program_id(axis=1)
|
||||
tile_offs = tl.arange(0, TILE_SIZE)
|
||||
tile_pos = tile_i * TILE_SIZE + tile_offs
|
||||
src_key_idx = token_idx * key_stride
|
||||
src_value_idx = token_idx * value_stride
|
||||
|
||||
if USE_HEAD_MAJOR_LAYOUT:
|
||||
# Decompose the tile index back into head and dim coordinates.
|
||||
cur_head = tile_pos // head_size
|
||||
cur_dim = tile_pos % head_size
|
||||
# Value addressing (4D): [Block, Head, Dim, Slot]
|
||||
tgt_idx_v = (
|
||||
block_idx * block_stride
|
||||
+ cur_head * head_stride
|
||||
+ cur_dim * dim_stride_v
|
||||
+ block_offset * 1
|
||||
)
|
||||
# Key addressing (5D): [Block, Head, Dim//8, Slot, 8]
|
||||
tgt_idx_k = (
|
||||
block_idx * block_stride
|
||||
+ cur_head * head_stride
|
||||
+ (cur_dim // x) * dim_stride_k
|
||||
+ block_offset * x
|
||||
+ (cur_dim % x)
|
||||
)
|
||||
else:
|
||||
tgt_base = block_idx * block_stride + block_offset * page_stride
|
||||
tgt_idx_k = tgt_base + tile_pos
|
||||
tgt_idx_v = tgt_base + tile_pos
|
||||
|
||||
# [TILE_SIZE]
|
||||
key_load = tl.load(
|
||||
key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
|
||||
)
|
||||
if FP8_KV_CACHE:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the key_cache_ptr.dtype.element_ty
|
||||
key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
|
||||
else:
|
||||
key_tile = key_load
|
||||
|
||||
# [TILE_SIZE]
|
||||
value_load = tl.load(
|
||||
value_ptr + src_value_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
|
||||
)
|
||||
if FP8_KV_CACHE:
|
||||
if value_load.dtype.is_fp8():
|
||||
value_tile = value_load
|
||||
else:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the value_cache_ptr.dtype.element_ty
|
||||
value_tile = value_load / tl.load(v_scale)
|
||||
else:
|
||||
value_tile = value_load
|
||||
|
||||
tl.store(
|
||||
key_cache_ptr + tgt_idx_k,
|
||||
key_tile,
|
||||
mask=tile_pos < (num_heads * head_size),
|
||||
)
|
||||
tl.store(
|
||||
value_cache_ptr + tgt_idx_v,
|
||||
value_tile,
|
||||
mask=tile_pos < (num_heads * head_size),
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def triton_reshape_and_cache_flash(
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
# [num_blocks, block_size, num_heads, head_size]
|
||||
key_cache: torch.Tensor,
|
||||
# [num_blocks, block_size, num_heads, head_size]
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor, # [num_tokens]
|
||||
kv_cache_dtype: str, # "auto", "fp8"
|
||||
k_scale: torch.Tensor, # float32
|
||||
v_scale: torch.Tensor, # float32
|
||||
):
|
||||
num_heads = key.shape[1]
|
||||
head_size = key.shape[2]
|
||||
|
||||
use_head_major_layout = key_cache.ndim == 5
|
||||
if use_head_major_layout:
|
||||
block_size = key_cache.shape[3]
|
||||
x = key_cache.shape[4]
|
||||
head_stride = key_cache.stride(1)
|
||||
dim_stride_k = key_cache.stride(2)
|
||||
dim_stride_v = value_cache.stride(2)
|
||||
else:
|
||||
block_size = key_cache.shape[1]
|
||||
x = 1
|
||||
dim_stride_k = 0
|
||||
dim_stride_v = 0
|
||||
head_stride = key_cache.stride()[2]
|
||||
n = num_heads * head_size
|
||||
key_stride = key.stride()[0]
|
||||
value_stride = value.stride()[0]
|
||||
block_stride = key_cache.stride()[0]
|
||||
page_stride = key_cache.stride()[1]
|
||||
|
||||
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
|
||||
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
|
||||
)
|
||||
kv_cache_torch_dtype = (
|
||||
current_platform.fp8_dtype()
|
||||
if kv_cache_dtype.startswith("fp8")
|
||||
else key_cache.dtype
|
||||
)
|
||||
|
||||
if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
|
||||
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
|
||||
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
|
||||
key_cache = key_cache.view(kv_cache_torch_dtype)
|
||||
value_cache = value_cache.view(kv_cache_torch_dtype)
|
||||
assert kv_cache_dtype != torch.uint8, (
|
||||
"explicit fp8 cast and store to "
|
||||
"uint8 is not supported by triton reshape_and_cache_flash"
|
||||
)
|
||||
|
||||
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
|
||||
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.uint8,
|
||||
torch.float8_e4m3fnuz,
|
||||
], (
|
||||
"unsupported dtype of KV cache tensor, got "
|
||||
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
|
||||
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
|
||||
)
|
||||
|
||||
# heuristics instead of autotuning
|
||||
TILE_SIZE = min(2048, triton.next_power_of_2(n))
|
||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||
num_stages = 4
|
||||
num_warps = 8
|
||||
else: # cuda
|
||||
num_stages = 10
|
||||
num_warps = 16
|
||||
if torch.cuda.get_device_capability(key.device)[0] < 9:
|
||||
TILE_SIZE = min(512, TILE_SIZE)
|
||||
|
||||
# TODO(ngl): maybe replace with static launch grid to avoid overhead if
|
||||
# using cudagraphs
|
||||
grid = lambda meta: (
|
||||
slot_mapping.shape[0],
|
||||
triton.cdiv(n, meta["TILE_SIZE"]),
|
||||
)
|
||||
|
||||
reshape_and_cache_kernel_flash[grid](
|
||||
key_ptr=key,
|
||||
value_ptr=value,
|
||||
key_cache_ptr=key_cache,
|
||||
value_cache_ptr=value_cache,
|
||||
slot_mapping_ptr=slot_mapping,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
# strides
|
||||
key_stride=key_stride,
|
||||
value_stride=value_stride,
|
||||
block_stride=block_stride,
|
||||
head_stride=head_stride,
|
||||
dim_stride_k=dim_stride_k,
|
||||
dim_stride_v=dim_stride_v,
|
||||
page_stride=page_stride,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
block_size=block_size,
|
||||
x=x,
|
||||
USE_HEAD_MAJOR_LAYOUT=use_head_major_layout,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE=FP8_KV_CACHE,
|
||||
# autotune parameters
|
||||
TILE_SIZE=TILE_SIZE,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def reshape_and_cache_kernel_flash_diffkv(
|
||||
key_ptr, # [num_tokens, num_heads, head_size]
|
||||
value_ptr, # [num_tokens, num_heads, head_size_v]
|
||||
kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v]
|
||||
slot_mapping_ptr, # [num_tokens]
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
# strides
|
||||
key_stride: tl.int64,
|
||||
value_stride: tl.int64,
|
||||
block_stride: tl.int64,
|
||||
page_stride: tl.int64,
|
||||
num_heads: tl.constexpr,
|
||||
head_size_k: tl.constexpr,
|
||||
head_size_v: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE: tl.constexpr,
|
||||
# tune parameters
|
||||
TILE_SIZE: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(axis=0)
|
||||
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
|
||||
if slot_idx < 0:
|
||||
# Padding token that should be ignored.
|
||||
return
|
||||
|
||||
tile_i = tl.program_id(axis=1)
|
||||
tile_offs = tl.arange(0, TILE_SIZE)
|
||||
|
||||
block_idx = slot_idx // block_size
|
||||
block_offset = slot_idx % block_size
|
||||
|
||||
src_key_idx = token_idx * key_stride + tile_i * head_size_k
|
||||
src_value_idx = token_idx * value_stride + tile_i * head_size_v
|
||||
|
||||
tgt_idx = (
|
||||
block_idx * block_stride
|
||||
+ block_offset * page_stride
|
||||
+ tile_i * (head_size_k + head_size_v)
|
||||
)
|
||||
|
||||
# [TILE_SIZE]
|
||||
key_load = tl.load(key_ptr + src_key_idx + tile_offs, mask=tile_offs < head_size_k)
|
||||
if FP8_KV_CACHE:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the key_cache_ptr.dtype.element_ty
|
||||
key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
|
||||
else:
|
||||
key_tile = key_load
|
||||
|
||||
# [TILE_SIZE]
|
||||
value_load = tl.load(
|
||||
value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v
|
||||
)
|
||||
if FP8_KV_CACHE:
|
||||
if value_load.dtype.is_fp8():
|
||||
value_tile = value_load
|
||||
else:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the value_cache_ptr.dtype.element_ty
|
||||
value_tile = value_load / tl.load(v_scale)
|
||||
else:
|
||||
value_tile = value_load
|
||||
|
||||
tl.store(
|
||||
kv_cache_ptr + tgt_idx + tile_offs,
|
||||
key_tile,
|
||||
mask=tile_offs < head_size_k,
|
||||
)
|
||||
tl.store(
|
||||
kv_cache_ptr + tgt_idx + head_size_k + tile_offs,
|
||||
value_tile,
|
||||
mask=tile_offs < head_size_v,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def triton_reshape_and_cache_flash_diffkv(
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads, head_size_v]
|
||||
# [num_blocks, block_size, num_heads, head_size + head_size_v]
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor, # [num_tokens]
|
||||
kv_cache_dtype: str, # "auto", "fp8"
|
||||
k_scale: torch.Tensor, # float32
|
||||
v_scale: torch.Tensor, # float32
|
||||
):
|
||||
num_heads = key.shape[1]
|
||||
head_size_k = key.shape[2]
|
||||
head_size_v = value.shape[2]
|
||||
block_size = kv_cache.shape[1]
|
||||
|
||||
k_stride = key.stride()[0]
|
||||
v_stride = value.stride()[0]
|
||||
block_stride = kv_cache.stride()[0]
|
||||
page_stride = kv_cache.stride()[1]
|
||||
|
||||
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
|
||||
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
|
||||
)
|
||||
kv_cache_torch_dtype = (
|
||||
current_platform.fp8_dtype()
|
||||
if kv_cache_dtype.startswith("fp8")
|
||||
else kv_cache.dtype
|
||||
)
|
||||
|
||||
if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
|
||||
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
|
||||
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
|
||||
kv_cache = kv_cache.view(kv_cache_torch_dtype)
|
||||
assert kv_cache_dtype != torch.uint8, (
|
||||
"explicit fp8 cast and store to "
|
||||
"uint8 is not supported by triton reshape_and_cache_flash_diffkv"
|
||||
)
|
||||
|
||||
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
|
||||
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.uint8,
|
||||
torch.float8_e4m3fnuz,
|
||||
], (
|
||||
"unsupported dtype of KV cache tensor, got "
|
||||
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
|
||||
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
|
||||
)
|
||||
|
||||
# heuristics instead of autotuning
|
||||
TILE_SIZE = max(head_size_k, head_size_v)
|
||||
TILE_SIZE = triton.next_power_of_2(TILE_SIZE)
|
||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||
num_stages = 4
|
||||
num_warps = 8
|
||||
else: # cuda
|
||||
num_stages = 10
|
||||
num_warps = 16
|
||||
|
||||
# TODO(ngl): maybe replace with static launch grid to avoid overhead if
|
||||
# using cudagraphs
|
||||
grid = lambda meta: (
|
||||
slot_mapping.shape[0],
|
||||
num_heads,
|
||||
)
|
||||
|
||||
reshape_and_cache_kernel_flash_diffkv[grid](
|
||||
key_ptr=key,
|
||||
value_ptr=value,
|
||||
kv_cache_ptr=kv_cache,
|
||||
slot_mapping_ptr=slot_mapping,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
# strides
|
||||
key_stride=k_stride,
|
||||
value_stride=v_stride,
|
||||
block_stride=block_stride,
|
||||
page_stride=page_stride,
|
||||
num_heads=num_heads,
|
||||
head_size_k=head_size_k,
|
||||
head_size_v=head_size_v,
|
||||
block_size=block_size,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE=FP8_KV_CACHE,
|
||||
# autotune parameters
|
||||
TILE_SIZE=TILE_SIZE,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
1065
vllm/v1/attention/ops/triton_unified_attention.py
Normal file
1065
vllm/v1/attention/ops/triton_unified_attention.py
Normal file
File diff suppressed because it is too large
Load Diff
185
vllm/v1/attention/ops/vit_attn_wrappers.py
Normal file
185
vllm/v1/attention/ops/vit_attn_wrappers.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains ops for ViT attention to be compatible with torch.compile
|
||||
as there are operations here not supported by torch.compile (for instance,
|
||||
`.item()` in flash attention)
|
||||
|
||||
Using these ops and wrapping vision blocks with `torch.compile` can speed up
|
||||
throughput in vision models by ~5% relative on H100, and improve token
|
||||
latencies by ~7% (see qwen2_5_vl for example usage)
|
||||
|
||||
To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
|
||||
"""
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
def flash_attn_maxseqlen_wrapper(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
fa_version: int | None,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
kwargs = {}
|
||||
if is_rocm_aiter:
|
||||
from aiter import flash_attn_varlen_func
|
||||
else:
|
||||
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
|
||||
|
||||
if not current_platform.is_rocm() and fa_version is not None:
|
||||
kwargs["fa_version"] = fa_version
|
||||
|
||||
q_len = q.size(1)
|
||||
if cu_seqlens is None:
|
||||
cu_seqlens = torch.arange(
|
||||
0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
|
||||
)
|
||||
max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
|
||||
|
||||
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
output = flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen,
|
||||
max_seqlen_k=max_seqlen,
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
softmax_scale=scale,
|
||||
**kwargs,
|
||||
)
|
||||
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
|
||||
return context_layer
|
||||
|
||||
|
||||
def flash_attn_maxseqlen_wrapper_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
fa_version: int | None,
|
||||
scale: float | None,
|
||||
cu_seqlens: torch.Tensor | None,
|
||||
max_seqlen: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="flash_attn_maxseqlen_wrapper",
|
||||
op_func=flash_attn_maxseqlen_wrapper,
|
||||
fake_impl=flash_attn_maxseqlen_wrapper_fake,
|
||||
)
|
||||
|
||||
|
||||
def vit_flash_attn_wrapper(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
fa_version: int | None,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
batch_size,
|
||||
is_rocm_aiter,
|
||||
fa_version,
|
||||
scale,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
|
||||
|
||||
def apply_sdpa(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Input shape:
|
||||
(batch_size x seq_len x num_heads x head_size)
|
||||
"""
|
||||
q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
|
||||
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=scale)
|
||||
output = einops.rearrange(output, "b h s d -> b s h d ")
|
||||
return output
|
||||
|
||||
|
||||
# TODO: Once we have a torch 2.10, we can use tensor slices
|
||||
# so we won't need to wrap this in custom ops
|
||||
def torch_sdpa_wrapper(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Never remove the contiguous logic for ROCm
|
||||
# Without it, hallucinations occur with the backend
|
||||
if current_platform.is_rocm():
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
|
||||
if cu_seqlens is None:
|
||||
return apply_sdpa(q, k, v, scale=scale)
|
||||
|
||||
outputs = []
|
||||
|
||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
q_chunks = torch.split(q, lens, dim=1)
|
||||
k_chunks = torch.split(k, lens, dim=1)
|
||||
v_chunks = torch.split(v, lens, dim=1)
|
||||
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||
output_i = apply_sdpa(q_i, k_i, v_i, scale=scale)
|
||||
outputs.append(output_i)
|
||||
context_layer = torch.cat(outputs, dim=1)
|
||||
return context_layer
|
||||
|
||||
|
||||
def torch_sdpa_wrapper_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float | None,
|
||||
cu_seqlens: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="torch_sdpa_wrapper",
|
||||
op_func=torch_sdpa_wrapper,
|
||||
fake_impl=torch_sdpa_wrapper_fake,
|
||||
)
|
||||
|
||||
|
||||
def vit_torch_sdpa_wrapper(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, scale, cu_seqlens)
|
||||
145
vllm/v1/attention/selector.py
Normal file
145
vllm/v1/attention/selector.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import cache
|
||||
from typing import NamedTuple, cast, get_args
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionType
|
||||
from vllm.v1.attention.backends.registry import (
|
||||
MAMBA_TYPE_TO_BACKEND_MAP,
|
||||
MambaAttentionBackendEnum,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AttentionSelectorConfig(NamedTuple):
|
||||
head_size: int
|
||||
dtype: torch.dtype
|
||||
kv_cache_dtype: CacheDType | None
|
||||
block_size: int | None
|
||||
use_mla: bool = False
|
||||
has_sink: bool = False
|
||||
use_sparse: bool = False
|
||||
use_mm_prefix: bool = False
|
||||
attn_type: str = AttentionType.DECODER
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"AttentionSelectorConfig(head_size={self.head_size}, "
|
||||
f"dtype={self.dtype}, "
|
||||
f"kv_cache_dtype={self.kv_cache_dtype}, "
|
||||
f"block_size={self.block_size}, "
|
||||
f"use_mla={self.use_mla}, "
|
||||
f"has_sink={self.has_sink}, "
|
||||
f"use_sparse={self.use_sparse}, "
|
||||
f"use_mm_prefix={self.use_mm_prefix}, "
|
||||
f"attn_type={self.attn_type})"
|
||||
)
|
||||
|
||||
|
||||
def get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
block_size: int | None,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
use_mm_prefix: bool = False,
|
||||
attn_type: str | None = None,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
|
||||
if kv_cache_dtype is not None:
|
||||
valid_cache_dtypes = get_args(CacheDType)
|
||||
assert kv_cache_dtype in valid_cache_dtypes, (
|
||||
f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
|
||||
f"Valid values are: {valid_cache_dtypes}"
|
||||
)
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
backend_enum = vllm_config.attention_config.backend
|
||||
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
|
||||
block_size=block_size,
|
||||
use_mla=use_mla,
|
||||
has_sink=has_sink,
|
||||
use_sparse=use_sparse,
|
||||
use_mm_prefix=use_mm_prefix,
|
||||
attn_type=attn_type or AttentionType.DECODER,
|
||||
)
|
||||
|
||||
return _cached_get_attn_backend(
|
||||
backend=backend_enum,
|
||||
attn_selector_config=attn_selector_config,
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def _cached_get_attn_backend(
|
||||
backend,
|
||||
attn_selector_config: AttentionSelectorConfig,
|
||||
) -> type[AttentionBackend]:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
backend,
|
||||
attn_selector_config=attn_selector_config,
|
||||
)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}"
|
||||
)
|
||||
backend = resolve_obj_by_qualname(attention_cls)
|
||||
|
||||
# Adjust kv cache layout if the selected backend requires a specific one
|
||||
required_layout = backend.get_required_kv_cache_layout()
|
||||
if required_layout is not None:
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
set_kv_cache_layout(required_layout)
|
||||
logger.info(
|
||||
"Using %s KV cache layout for %s backend.",
|
||||
required_layout,
|
||||
backend.get_name(),
|
||||
)
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
def get_mamba_attn_backend(
|
||||
mamba_type: str,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Select which mamba attention backend to use and lazily import it."""
|
||||
return _cached_get_mamba_attn_backend(mamba_type)
|
||||
|
||||
|
||||
@cache
|
||||
def _cached_get_mamba_attn_backend(
|
||||
mamba_type: str,
|
||||
) -> type[AttentionBackend]:
|
||||
assert mamba_type and isinstance(mamba_type, str)
|
||||
|
||||
selected_backend = None
|
||||
try:
|
||||
backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
|
||||
selected_backend = MambaAttentionBackendEnum[backend_name]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
f"Invalid mamba attention backend type: '{backend_name}'. Valid "
|
||||
f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
|
||||
) from e
|
||||
|
||||
mamba_attn_backend = selected_backend.get_class()
|
||||
return mamba_attn_backend
|
||||
Reference in New Issue
Block a user