[HMA]Fix corner case when hybrid page_size can not be evenly divided issue (blk_size=64,tp=4) (#37467)
Signed-off-by: Chendi Xue <chendi.xue@intel.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
@@ -851,6 +851,7 @@ def test_hybrid_attention_mamba_tensor_shapes():
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
current_platform.update_block_size_for_backend(vllm_config)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
|
||||
available_memory = 5 * GiB_bytes
|
||||
@@ -1306,6 +1307,7 @@ def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
|
||||
assert fwd_context is not None
|
||||
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
current_platform.update_block_size_for_backend(vllm_config)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
|
||||
available_memory = 5 * GiB_bytes
|
||||
|
||||
@@ -38,6 +38,8 @@ class CacheConfig:
|
||||
Accepts None (meaning "use default"). After construction, always int."""
|
||||
user_specified_block_size: bool = field(default=False, init=False)
|
||||
"""Whether block_size was explicitly provided. Derived automatically."""
|
||||
user_specified_mamba_block_size: bool = field(default=False, init=False)
|
||||
"""Whether mamba_block_size was explicitly provided. Derived automatically."""
|
||||
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
|
||||
"""The fraction of GPU memory to be used for the model executor, which can
|
||||
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
|
||||
@@ -182,6 +184,7 @@ class CacheConfig:
|
||||
"cpu_kvcache_space_bytes",
|
||||
"mamba_page_size_padded",
|
||||
"user_specified_block_size",
|
||||
"user_specified_mamba_block_size",
|
||||
"_block_size_resolved",
|
||||
# Post-init/derived counters
|
||||
"num_gpu_blocks",
|
||||
@@ -214,6 +217,8 @@ class CacheConfig:
|
||||
object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE)
|
||||
else:
|
||||
object.__setattr__(self, "user_specified_block_size", True)
|
||||
if self.mamba_block_size is not None:
|
||||
object.__setattr__(self, "user_specified_mamba_block_size", True)
|
||||
return self
|
||||
|
||||
@field_validator("calculate_kv_scales", mode="after")
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from copy import deepcopy
|
||||
from math import lcm
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
@@ -104,11 +99,11 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Ensure that page size of attention layers is greater than or
|
||||
equal to the mamba layers. If not, automatically set the attention
|
||||
block size to ensure that it is. If the attention page size is
|
||||
strictly greater than the mamba page size, we pad the mamba page size
|
||||
to make them equal.
|
||||
Perform early validation and setup for hybrid attention/mamba models.
|
||||
|
||||
Block size alignment with mamba page sizes is handled later by
|
||||
Platform.update_block_size_for_backend(), which runs after model
|
||||
layers are constructed and the attention backend is known.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM Config
|
||||
@@ -118,6 +113,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
# Disable calculate_kv_scales for hybrid models: uninitialized
|
||||
# recurrent state corrupts scales during the calibration pass.
|
||||
# See issue: https://github.com/vllm-project/vllm/issues/37554
|
||||
|
||||
if cache_config.calculate_kv_scales:
|
||||
logger.warning(
|
||||
"Disabling calculate_kv_scales for hybrid model '%s'. "
|
||||
@@ -129,140 +125,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
)
|
||||
cache_config.calculate_kv_scales = False
|
||||
|
||||
# Save the user input before it gets modified by MambaModelConfig
|
||||
mamba_block_size = cache_config.mamba_block_size
|
||||
# Enable FULL_AND_PIECEWISE by default
|
||||
MambaModelConfig.verify_and_update_config(vllm_config)
|
||||
|
||||
attention_config = vllm_config.attention_config
|
||||
cache_config = vllm_config.cache_config
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
if cache_config.cache_dtype == "auto":
|
||||
kv_cache_dtype = model_config.dtype
|
||||
else:
|
||||
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
# get attention page size (for 1 token)
|
||||
# Attention backend constraints:
|
||||
# - FlashAttention (FA) requires block size to be multiple of 16
|
||||
# - MLA (Multi-head Latent Attention) requires larger alignment:
|
||||
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
|
||||
# * Other MLA backends: kernel_block_size 64 alignment
|
||||
if model_config.use_mla:
|
||||
use_cutlass_mla = (
|
||||
attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
|
||||
)
|
||||
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
|
||||
attn_page_size_1_token = MLAAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
).page_size_bytes
|
||||
else:
|
||||
kernel_block_alignment_size = 16
|
||||
attn_page_size_1_token = FullAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
).page_size_bytes
|
||||
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||
model_config.architecture,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# get mamba page size
|
||||
mamba_page_size = MambaSpec(
|
||||
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
|
||||
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
|
||||
block_size=-1, # block_size doesn't matter for mamba page size
|
||||
).page_size_bytes
|
||||
|
||||
# Model may be marked as is_hybrid
|
||||
# but mamba is skipped via config,
|
||||
# return directly
|
||||
if mamba_page_size == 0:
|
||||
return
|
||||
|
||||
if cache_config.mamba_cache_mode == "all":
|
||||
# With prefix caching, select attention block size to
|
||||
# optimize for mamba kernel performance
|
||||
|
||||
# Mamba2 SSD kernel uses a chunk_size, e.g. 256
|
||||
# Align the block to the kernel: use lowest multiple of chunk_size
|
||||
# of attention tokens that would fit mamba_page_size:
|
||||
# e.g. for mamba page size = 788kB
|
||||
# attn_1_token = 2kB -> fits ~394 tokens
|
||||
# then round up to a multiple of 256 -> 512 tokens
|
||||
# End result:
|
||||
# attn_block_size = 512
|
||||
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
|
||||
# TODO(tdoublep): this constraint can be relaxed fairly
|
||||
# easily by changing the way we layout chunks in the
|
||||
# mamba2 kernels.
|
||||
|
||||
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
|
||||
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
|
||||
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
|
||||
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
|
||||
cache_config.mamba_block_size = attn_block_size
|
||||
else:
|
||||
# Without prefix caching, select minimum valid attention block size
|
||||
# to minimize mamba state padding
|
||||
|
||||
# Calculate minimum attention block size that satisfies both:
|
||||
# 1. Backend alignment requirements (kernel_block_alignment_size)
|
||||
# 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
|
||||
attn_block_size = kernel_block_alignment_size * cdiv(
|
||||
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
|
||||
)
|
||||
|
||||
# override attention block size if it is too small,
|
||||
# even if the user has explicitly set it
|
||||
if cache_config.block_size < attn_block_size:
|
||||
cache_config.block_size = attn_block_size
|
||||
logger.info(
|
||||
"Setting attention block size to %d tokens "
|
||||
"to ensure that attention page size is >= mamba page size.",
|
||||
attn_block_size,
|
||||
)
|
||||
|
||||
# By default, mamba block size will be set to max_model_len.
|
||||
# When enabling prefix caching and using align mamba cache
|
||||
# mode, we align mamba block size to the block size as the
|
||||
# basic granularity for prefix caching.
|
||||
if cache_config.mamba_cache_mode == "align":
|
||||
cache_config.mamba_block_size = cache_config.block_size
|
||||
|
||||
# compute new attention page size
|
||||
attn_page_size = cache_config.block_size * attn_page_size_1_token
|
||||
|
||||
assert attn_page_size >= mamba_page_size
|
||||
|
||||
if attn_page_size == mamba_page_size:
|
||||
# don't need to pad mamba page size
|
||||
return
|
||||
|
||||
# pad mamba page size to exactly match attention
|
||||
if (
|
||||
cache_config.mamba_page_size_padded is None
|
||||
or cache_config.mamba_page_size_padded != attn_page_size
|
||||
):
|
||||
cache_config.mamba_page_size_padded = attn_page_size
|
||||
mamba_padding_pct = (
|
||||
100 * (attn_page_size - mamba_page_size) / mamba_page_size
|
||||
)
|
||||
logger.info(
|
||||
"Padding mamba page size by %.2f%% to ensure "
|
||||
"that mamba page size and attention page size are "
|
||||
"exactly equal.",
|
||||
mamba_padding_pct,
|
||||
)
|
||||
|
||||
|
||||
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
|
||||
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.selector import AttentionSelectorConfig
|
||||
else:
|
||||
FlexibleArgumentParser = object
|
||||
@@ -424,29 +425,11 @@ class Platform:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Ensure block_size is compatible with the attention backend.
|
||||
"""
|
||||
from vllm.config.cache import CacheConfig
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config.user_specified_block_size:
|
||||
# User specified --block-size; keep it.
|
||||
return
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
# model_config may be None during testing.
|
||||
# Skip hybrid models — their block_size is managed by
|
||||
# HybridAttentionMambaModelConfig.
|
||||
if model_config is None or model_config.is_hybrid:
|
||||
cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE
|
||||
return
|
||||
|
||||
from vllm.config.vllm import (
|
||||
get_layers_from_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
def _find_non_ssm_backend(
|
||||
cls, vllm_config: "VllmConfig"
|
||||
) -> "type[AttentionBackend] | None":
|
||||
"""Find the first non-SSM attention backend from model layers."""
|
||||
from vllm.config.vllm import get_layers_from_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import (
|
||||
AttentionLayerBase,
|
||||
)
|
||||
@@ -455,23 +438,181 @@ class Platform:
|
||||
vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
)
|
||||
if not attn_layers:
|
||||
cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE
|
||||
for layer in attn_layers.values():
|
||||
b = layer.get_attn_backend()
|
||||
if not b.is_ssm():
|
||||
return b
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Ensure block_size is compatible with the attention backend.
|
||||
For hybrid models, also aligns block_size with mamba page sizes.
|
||||
"""
|
||||
from vllm.config.cache import CacheConfig
|
||||
from vllm.config.vllm import set_current_vllm_config
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
# model_config may be None during testing.
|
||||
if not model_config:
|
||||
return
|
||||
|
||||
first_layer = next(iter(attn_layers.values()))
|
||||
backend_cls = first_layer.get_attn_backend()
|
||||
backend_cls = cls._find_non_ssm_backend(vllm_config)
|
||||
if backend_cls is None:
|
||||
return
|
||||
|
||||
# Phase 1: Pick block size from backend (skip if user set --block-size)
|
||||
if not cache_config.user_specified_block_size:
|
||||
with set_current_vllm_config(vllm_config):
|
||||
preferred = backend_cls.get_preferred_block_size(
|
||||
CacheConfig.DEFAULT_BLOCK_SIZE
|
||||
)
|
||||
if preferred != CacheConfig.DEFAULT_BLOCK_SIZE:
|
||||
logger.info(
|
||||
"Setting kv cache block size to %d for %s backend.",
|
||||
preferred,
|
||||
backend_cls.get_name(),
|
||||
)
|
||||
cache_config.block_size = preferred
|
||||
|
||||
# Phase 2: Align block/mamba sizes for hybrid models
|
||||
# (may override user settings).
|
||||
if model_config.is_hybrid:
|
||||
cls._align_hybrid_block_size(vllm_config, backend_cls)
|
||||
|
||||
@classmethod
|
||||
def _align_hybrid_block_size(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
backend_cls: "type[AttentionBackend]",
|
||||
) -> None:
|
||||
"""
|
||||
For hybrid attention/mamba models, ensure that the attention page
|
||||
size is >= the mamba page size, and pad the mamba page size to match.
|
||||
"""
|
||||
from math import lcm
|
||||
|
||||
from vllm.config.vllm import set_current_vllm_config
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.attention.backend import MultipleOf
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
)
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
if cache_config.cache_dtype == "auto":
|
||||
kv_cache_dtype = model_config.dtype
|
||||
else:
|
||||
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
# Compute attention page size for 1 token
|
||||
if model_config.use_mla:
|
||||
attn_page_size_1_token = MLAAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
).page_size_bytes
|
||||
else:
|
||||
attn_page_size_1_token = FullAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
).page_size_bytes
|
||||
|
||||
# Compute mamba page size
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||
model_config.architecture,
|
||||
model_config=model_config,
|
||||
)
|
||||
mamba_page_size = MambaSpec(
|
||||
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
|
||||
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
|
||||
block_size=-1,
|
||||
).page_size_bytes
|
||||
|
||||
if mamba_page_size == 0:
|
||||
return
|
||||
|
||||
# mamba_block_size here should either be user specified value or None
|
||||
mamba_block_size = (
|
||||
cache_config.mamba_block_size
|
||||
if cache_config.user_specified_mamba_block_size
|
||||
else None
|
||||
)
|
||||
|
||||
# Get kernel block alignment from the backend's supported sizes
|
||||
with set_current_vllm_config(vllm_config):
|
||||
preferred = backend_cls.get_preferred_block_size(
|
||||
CacheConfig.DEFAULT_BLOCK_SIZE
|
||||
kernel_block_alignment_size = max(
|
||||
min(
|
||||
s.base if isinstance(s, MultipleOf) else s
|
||||
for s in backend_cls.get_supported_kernel_block_sizes()
|
||||
),
|
||||
cache_config.block_size,
|
||||
)
|
||||
if preferred != CacheConfig.DEFAULT_BLOCK_SIZE:
|
||||
|
||||
if cache_config.mamba_cache_mode == "all":
|
||||
# With prefix caching, align to mamba chunk size for kernel perf
|
||||
# TODO(tdoublep): this constraint can be relaxed fairly
|
||||
# easily by changing the way we layout chunks in the
|
||||
# mamba2 kernels.
|
||||
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
|
||||
assert base_chunk_size is not None
|
||||
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
|
||||
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
|
||||
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
|
||||
cache_config.mamba_block_size = attn_block_size
|
||||
else:
|
||||
# Without prefix caching, use minimum block size that satisfies
|
||||
# both backend alignment and mamba page size compatibility
|
||||
attn_block_size = kernel_block_alignment_size * cdiv(
|
||||
mamba_page_size,
|
||||
kernel_block_alignment_size * attn_page_size_1_token,
|
||||
)
|
||||
|
||||
if cache_config.block_size < attn_block_size:
|
||||
cache_config.block_size = attn_block_size
|
||||
logger.info(
|
||||
"Setting kv cache block size to %d for %s backend.",
|
||||
preferred,
|
||||
backend_cls.get_name(),
|
||||
"Setting attention block size to %d tokens "
|
||||
"to ensure that attention page size is >= mamba page size.",
|
||||
attn_block_size,
|
||||
)
|
||||
|
||||
if cache_config.mamba_cache_mode == "align":
|
||||
cache_config.mamba_block_size = cache_config.block_size
|
||||
|
||||
# Pad mamba page size to exactly match attention page size
|
||||
attn_page_size = cache_config.block_size * attn_page_size_1_token
|
||||
assert attn_page_size >= mamba_page_size
|
||||
|
||||
if attn_page_size == mamba_page_size:
|
||||
return
|
||||
|
||||
if (
|
||||
cache_config.mamba_page_size_padded is None
|
||||
or cache_config.mamba_page_size_padded != attn_page_size
|
||||
):
|
||||
cache_config.mamba_page_size_padded = attn_page_size
|
||||
mamba_padding_pct = (
|
||||
100 * (attn_page_size - mamba_page_size) / mamba_page_size
|
||||
)
|
||||
logger.info(
|
||||
"Padding mamba page size by %.2f%% to ensure "
|
||||
"that mamba page size and attention page size are "
|
||||
"exactly equal.",
|
||||
mamba_padding_pct,
|
||||
)
|
||||
cache_config.block_size = preferred
|
||||
|
||||
@classmethod
|
||||
def verify_model_arch(cls, model_arch: str) -> None:
|
||||
|
||||
@@ -160,11 +160,7 @@ class XPUPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
cache_config = vllm_config.cache_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
# in V1(or with chunked prefill) block_size is 64
|
||||
if cache_config and not cache_config.user_specified_block_size:
|
||||
cache_config.block_size = 64
|
||||
|
||||
# lazy import to avoid circular import
|
||||
from vllm.config import CUDAGraphMode
|
||||
@@ -221,12 +217,6 @@ class XPUPlatform(Platform):
|
||||
# ref. https://openucx.readthedocs.io/en/master/faq.html
|
||||
os.environ["UCX_MEMTYPE_CACHE"] = "n"
|
||||
|
||||
@classmethod
|
||||
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
|
||||
# TODO: XPU still sets block_size in check_and_update_config.
|
||||
# Move that logic here so block_size is chosen by the backend.
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def support_hybrid_kv_cache(cls) -> bool:
|
||||
return True
|
||||
|
||||
@@ -311,6 +311,10 @@ class AttentionBackend(ABC):
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_ssm(cls) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class AttentionMetadata:
|
||||
pass
|
||||
|
||||
@@ -43,6 +43,7 @@ from vllm.config import (
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.v1.attention.backend import (
|
||||
@@ -90,6 +91,12 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@classmethod
|
||||
def get_preferred_block_size(cls, default_block_size: int) -> int:
|
||||
if current_platform.is_xpu():
|
||||
return max(default_block_size, 64)
|
||||
return super().get_preferred_block_size(default_block_size)
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN"
|
||||
|
||||
@@ -31,6 +31,10 @@ class GDNAttentionBackend(AttentionBackend):
|
||||
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
|
||||
return GDNAttentionMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def is_ssm(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class GDNAttentionMetadata:
|
||||
|
||||
@@ -27,6 +27,10 @@ class LinearAttentionBackend(AttentionBackend):
|
||||
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
|
||||
return LinearAttentionMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def is_ssm(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinearAttentionMetadata:
|
||||
|
||||
@@ -20,6 +20,10 @@ class Mamba1AttentionBackend(AttentionBackend):
|
||||
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
|
||||
return Mamba1AttentionMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def is_ssm(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba1AttentionMetadata(BaseMambaAttentionMetadata):
|
||||
|
||||
@@ -96,6 +96,10 @@ class Mamba2AttentionBackend(AttentionBackend):
|
||||
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
|
||||
return Mamba2AttentionMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def is_ssm(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
|
||||
|
||||
@@ -18,6 +18,10 @@ class ShortConvAttentionBackend(AttentionBackend):
|
||||
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
|
||||
return ShortConvAttentionMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def is_ssm(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShortConvAttentionMetadata(BaseMambaAttentionMetadata):
|
||||
|
||||
Reference in New Issue
Block a user