[v1] Add PrefixLM support to FlexAttention backend (#27938)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -133,6 +133,7 @@ class CpuPlatform(Platform):
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
use_mm_prefix: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
|
||||
|
||||
@@ -233,6 +233,20 @@ class CudaPlatformBase(Platform):
|
||||
"Forcing kv cache block size to 64 for FlashMLASparse backend."
|
||||
)
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
# Note: model_config may be None during testing
|
||||
if (
|
||||
model_config is not None
|
||||
and model_config.is_mm_prefix_lm
|
||||
and scheduler_config.is_multimodal_model
|
||||
and not scheduler_config.disable_chunked_mm_input
|
||||
):
|
||||
logger.warning(
|
||||
"Forcing --disable_chunked_mm_input for models "
|
||||
"with multimodal-bidirectional attention."
|
||||
)
|
||||
scheduler_config.disable_chunked_mm_input = True
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(
|
||||
cls, device: torch.types.Device | None = None
|
||||
@@ -268,6 +282,7 @@ class CudaPlatformBase(Platform):
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
use_mm_prefix,
|
||||
device_capability,
|
||||
attn_type,
|
||||
) -> tuple[
|
||||
@@ -289,6 +304,7 @@ class CudaPlatformBase(Platform):
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
use_mm_prefix,
|
||||
device_capability,
|
||||
attn_type,
|
||||
)
|
||||
@@ -312,6 +328,7 @@ class CudaPlatformBase(Platform):
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
use_mm_prefix: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
if attn_type is None:
|
||||
@@ -332,6 +349,7 @@ class CudaPlatformBase(Platform):
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
use_mm_prefix,
|
||||
device_capability,
|
||||
attn_type,
|
||||
)
|
||||
@@ -356,6 +374,7 @@ class CudaPlatformBase(Platform):
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
use_mm_prefix,
|
||||
device_capability,
|
||||
attn_type,
|
||||
)
|
||||
|
||||
@@ -239,6 +239,7 @@ class Platform:
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
use_mm_prefix: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
"""Get the attention backend class of a device."""
|
||||
|
||||
@@ -216,6 +216,7 @@ class RocmPlatform(Platform):
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
use_mm_prefix,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
@@ -62,8 +62,9 @@ class TpuPlatform(Platform):
|
||||
kv_cache_dtype: str | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
-has_sink,
|
||||
-use_sparse,
|
||||
+has_sink: bool,
|
||||
+use_sparse: bool,
|
||||
+use_mm_prefix: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
if use_sparse:
|
||||
|
||||
@@ -48,7 +48,8 @@ class XPUPlatform(Platform):
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
-use_sparse,
|
||||
+use_sparse: bool,
|
||||
+use_mm_prefix: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
Reference in New Issue
Block a user