[v1] Add PrefixLM support to FlexAttention backend (#27938)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-12-07 23:51:36 +08:00
committed by GitHub
parent 541a2ef892
commit b952f4d3c3
16 changed files with 173 additions and 25 deletions

View File

@@ -133,6 +133,7 @@ class CpuPlatform(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:

View File

@@ -233,6 +233,20 @@ class CudaPlatformBase(Platform):
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing
if (
model_config is not None
and model_config.is_mm_prefix_lm
and scheduler_config.is_multimodal_model
and not scheduler_config.disable_chunked_mm_input
):
logger.warning(
"Forcing --disable_chunked_mm_input for models "
"with multimodal-bidirectional attention."
)
scheduler_config.disable_chunked_mm_input = True
@classmethod
def get_current_memory_usage(
cls, device: torch.types.Device | None = None
@@ -268,6 +282,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) -> tuple[
@@ -289,6 +304,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
)
@@ -312,6 +328,7 @@ class CudaPlatformBase(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if attn_type is None:
@@ -332,6 +349,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
)
@@ -356,6 +374,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
)

View File

@@ -239,6 +239,7 @@ class Platform:
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
"""Get the attention backend class of a device."""

View File

@@ -216,6 +216,7 @@ class RocmPlatform(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops

View File

@@ -62,8 +62,9 @@ class TpuPlatform(Platform):
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink,
use_sparse,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if use_sparse:

View File

@@ -48,7 +48,8 @@ class XPUPlatform(Platform):
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout