[V1] [Hybrid] Validate compatibility of attention backend batch reordering at init time (#21557)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Thomas Parnell
2025-08-02 14:29:40 +02:00
committed by GitHub
parent f5d0f4784f
commit 4abfd8796f
7 changed files with 96 additions and 72 deletions
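In short, this change drops the per-backend reorder_batch() overrides shown in the diffs below and replaces them with a declarative reorder_batch_threshold class attribute on each attention metadata builder, so the batch-reordering behaviour of every backend can be inspected, and its compatibility validated, when the builders are constructed. A minimal sketch of the resulting pattern, with simplified class bodies (the real builders also take kv_cache_spec, layer_names, vllm_config and device in __init__):

from typing import ClassVar, Optional

class AttentionMetadataBuilder:
    # None -> this backend never reorders the batch.
    # N    -> requests whose query length is <= N are pulled to the
    #         front of the batch (decodes before prefills).
    reorder_batch_threshold: ClassVar[Optional[int]] = None

class FlashInferMetadataBuilder(AttentionMetadataBuilder):
    # Replaces the reorder_batch() override that previously called
    # reorder_batch_to_split_decodes_and_prefills(..., decode_threshold=1).
    reorder_batch_threshold: ClassVar[int] = 1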

View File

@@ -4,7 +4,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Union
from typing import ClassVar, Optional, Union
import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@@ -21,17 +21,17 @@ from vllm.logger import init_logger
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import use_trtllm_decode_attention
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout,
get_per_layer_parameters,
infer_global_hyperparameters,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
logger = init_logger(__name__)
@@ -179,6 +179,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.device = device
@@ -239,12 +241,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int32,
device=self.device)
def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
return reorder_batch_to_split_decodes_and_prefills(input_batch,
scheduler_output,
decode_threshold=1)
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(

View File

@@ -2,21 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
@@ -87,6 +83,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
@@ -95,12 +93,6 @@ class Mamba2AttentionMetadataBuilder(
assert self.chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models")
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return reorder_batch_to_split_decodes_and_prefills(input_batch,
scheduler_output,
decode_threshold=1)
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
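For context, the helper that the removed overrides all called, reorder_batch_to_split_decodes_and_prefills, moves decode requests (those with at most decode_threshold scheduled tokens) to the front of the batch and leaves prefills behind them. The list-based sketch below illustrates that effect only; the real helper operates on InputBatch, swaps requests in place, and returns whether anything moved:

def split_decodes_before_prefills(num_scheduled_tokens: list[int],
                                  decode_threshold: int = 1) -> list[int]:
    """Return request indices ordered so that 'decodes' (requests with
    <= decode_threshold scheduled tokens) precede prefills."""
    decodes = [i for i, n in enumerate(num_scheduled_tokens)
               if n <= decode_threshold]
    prefills = [i for i, n in enumerate(num_scheduled_tokens)
                if n > decode_threshold]
    return decodes + prefills

# Example: per-request token counts [1, 37, 1, 52] -> order [0, 2, 1, 3]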

View File

@@ -190,7 +190,7 @@ return curr_o @ W_O
import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
from typing import ClassVar, Generic, Optional, TypeVar, Union
import torch
@@ -210,10 +210,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
get_per_layer_parameters, infer_global_hyperparameters,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
get_per_layer_parameters,
infer_global_hyperparameters,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
try:
@@ -233,10 +234,6 @@ try:
except ImportError:
flashinfer_available = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
logger = init_logger(__name__)
CUDNN_WORKSPACE_SIZE = 12800
@@ -403,6 +400,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self,
kv_cache_spec: AttentionSpec,
@@ -559,12 +557,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill.prefill_main = self._fi_prefill_main
prefill.prefill_chunks = self._fi_prefill_chunks
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return reorder_batch_to_split_decodes_and_prefills(input_batch,
scheduler_output,
decode_threshold=1)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(

View File

@@ -251,9 +251,6 @@ class AiterFlashAttentionMetadataBuilder(
self.aot_sliding_window: Optional[tuple[int, int]] = None
self.total_tokens: int = 0
def reorder_batch(self, input_batch, scheduler_output) -> bool:
return False
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
self.total_tokens = self.model_config.max_model_len \

View File

@@ -167,6 +167,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[Optional[int]] = None
@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
@@ -221,14 +225,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
) -> bool:
return False
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
"""
This method can reorder the batch if desired by the backend.
:return: Has the batch been reordered (default False).
"""
return False
@functools.lru_cache
def get_kv_cache_layout():
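Finally, the init-time validation referred to in the title: because the threshold is now a plain class attribute, the model runner can check up front that every attention backend in a hybrid model (for example FlashInfer layers mixed with Mamba2 layers) asks for the same batch reordering, instead of discovering a mismatch at runtime. Backends that never reorder, such as the AiterFlashAttention builder whose override simply returned False, keep the default of None and are compatible with any threshold. The sketch below shows the kind of check this enables; the function name and error message are illustrative, not the exact code added by this commit:

from typing import Iterable, Optional

def validate_reorder_batch_thresholds(
        thresholds: Iterable[Optional[int]]) -> Optional[int]:
    """Ensure all metadata builders agree on batch reordering and
    return the common threshold (None if no builder reorders)."""
    distinct = {t for t in thresholds if t is not None}
    if len(distinct) > 1:
        raise ValueError(
            "Attention backends declare incompatible "
            f"reorder_batch_threshold values: {sorted(distinct)}")
    return next(iter(distinct), None)

# e.g. FlashInfer (1), Mamba2 (1) and MLA (1) agree -> threshold 1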