[V1] [Hybrid] Validate compatibility of attention backend batch reordering at init time (#21557)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
@@ -4,7 +4,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, ClassVar, Optional, Union
+from typing import ClassVar, Optional, Union
 
 import torch
 
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@@ -21,17 +21,17 @@ from vllm.logger import init_logger
 from vllm.utils import cdiv, is_pin_memory_available
 from vllm.utils.flashinfer import use_trtllm_decode_attention
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
-from vllm.v1.attention.backends.utils import (
-    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
-    get_kv_cache_layout, get_per_layer_parameters,
-    infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
-    split_decodes_and_prefills)
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout,
+                                              get_per_layer_parameters,
+                                              infer_global_hyperparameters,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 
 logger = init_logger(__name__)
@@ -179,6 +179,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.PURE_DECODE_ONLY
 
+    reorder_batch_threshold: ClassVar[int] = 1
+
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         self.device = device
@@ -239,12 +241,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  dtype=torch.int32,
                                                  device=self.device)
 
-    def reorder_batch(self, input_batch: InputBatch,
-                      scheduler_output: SchedulerOutput) -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(input_batch,
-                                                           scheduler_output,
-                                                           decode_threshold=1)
-
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             self._workspace_buffer = torch.empty(
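Taken together, the flashinfer.py hunks above swap an imperative `reorder_batch` override for a declarative class attribute. Below is a minimal sketch of the before/after pattern, assuming only what the diff shows; the stand-alone class names are invented for the comparison and the bodies are abbreviated.

```python
# Sketch of the pattern this commit applies; class bodies abbreviated.
from typing import ClassVar

from vllm.v1.attention.backends.utils import (
    reorder_batch_to_split_decodes_and_prefills)


class BuilderBefore:
    # Old style: each backend re-implemented the same reordering call,
    # hard-coding its decode threshold at the call site.
    def reorder_batch(self, input_batch, scheduler_output) -> bool:
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch, scheduler_output, decode_threshold=1)


class BuilderAfter:
    # New style: declare the threshold once; shared infrastructure can both
    # perform the reordering and validate it across backends at init time.
    reorder_batch_threshold: ClassVar[int] = 1
```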
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
@@ -2,21 +2,17 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import ClassVar, Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
-    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 
 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
                                               chunk_size: int,
@@ -87,6 +83,8 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
 
+    reorder_batch_threshold: ClassVar[int] = 1
+
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         assert isinstance(kv_cache_spec, MambaSpec)
@@ -95,12 +93,6 @@ class Mamba2AttentionMetadataBuilder(
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(input_batch,
-                                                           scheduler_output,
-                                                           decode_threshold=1)
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
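The `decode_threshold=1` argument that disappears here is exactly the value now carried by `reorder_batch_threshold`: requests whose query length is at most the threshold count as decodes and are pulled to the front of the batch. The standalone illustration below shows only that partitioning rule; it deliberately does not model vLLM's real `InputBatch` structures.

```python
# Illustration only: the decode/prefill split implied by the threshold.
def split_by_threshold(query_lens: list[int], threshold: int = 1) -> list[int]:
    """Return request indices with decodes (query len <= threshold) first."""
    decodes = [i for i, n in enumerate(query_lens) if n <= threshold]
    prefills = [i for i, n in enumerate(query_lens) if n > threshold]
    return decodes + prefills


# Requests 1 and 3 are single-token decodes, so they move to the front.
assert split_by_threshold([17, 1, 256, 1]) == [1, 3, 0, 2]
```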
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
@@ -190,7 +190,7 @@ return curr_o @ W_O
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
+from typing import ClassVar, Generic, Optional, TypeVar, Union
 
 import torch
 
@@ -210,10 +210,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 from vllm.utils.flashinfer import has_nvidia_artifactory
-from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
-    get_per_layer_parameters, infer_global_hyperparameters,
-    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata,
+                                              get_per_layer_parameters,
+                                              infer_global_hyperparameters,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 try:
@@ -233,10 +234,6 @@ try:
 except ImportError:
     flashinfer_available = False
 
-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 logger = init_logger(__name__)
 
 CUDNN_WORKSPACE_SIZE = 12800
@@ -403,6 +400,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
+    reorder_batch_threshold: ClassVar[int] = 1
 
     def __init__(self,
                  kv_cache_spec: AttentionSpec,
@@ -559,12 +557,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             prefill.prefill_main = self._fi_prefill_main
             prefill.prefill_chunks = self._fi_prefill_chunks
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(input_batch,
-                                                           scheduler_output,
-                                                           decode_threshold=1)
-
     def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor):
         return MLACommonDecodeMetadata(
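The MLA hunks repeat the same two-part edit: delete the `reorder_batch` override, declare `reorder_batch_threshold: ClassVar[int] = 1`. Because the attribute is a `ClassVar`, it is inspectable without constructing a builder, which is what makes a cheap init-time compatibility check possible. A small sketch of that property (the base-class default mirrors the utils.py hunk below; the minimal bodies are illustrative, not vLLM's real class hierarchy):

```python
from typing import ClassVar, Optional


class AttentionMetadataBuilder:
    # Base default (see the utils.py hunk below): None means
    # "this backend does not reorder the batch".
    reorder_batch_threshold: ClassVar[Optional[int]] = None


class MLACommonMetadataBuilder(AttentionMetadataBuilder):
    # MLA pulls single-token (decode) requests to the front.
    reorder_batch_threshold: ClassVar[int] = 1


# Inspectable at class level, before any builder instance exists.
assert AttentionMetadataBuilder.reorder_batch_threshold is None
assert MLACommonMetadataBuilder.reorder_batch_threshold == 1
```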
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -251,9 +251,6 @@ class AiterFlashAttentionMetadataBuilder(
         self.aot_sliding_window: Optional[tuple[int, int]] = None
         self.total_tokens: int = 0
 
-    def reorder_batch(self, input_batch, scheduler_output) -> bool:
-        return False
-
     def build_for_cudagraph_capture(
             self, common_attn_metadata: CommonAttentionMetadata):
         self.total_tokens = self.model_config.max_model_len \
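For backends that never reorder, such as this ROCm AITER flash-attention builder, the deleted `reorder_batch` stub returning `False` is not replaced by anything: inheriting the base class's `reorder_batch_threshold = None` default (added in the utils.py hunk below) now expresses "no reordering" declaratively.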
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
@@ -167,6 +167,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
     attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.NEVER
+    # Does this backend/builder reorder the batch?
+    # If not, set this to None. Otherwise set it to the query
+    # length that will be pulled into the front of the batch.
+    reorder_batch_threshold: ClassVar[Optional[int]] = None
 
     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
@@ -221,14 +225,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     ) -> bool:
         return False
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        """
-        This method can reorder the batch if desired by the backend.
-        :return: Has the batch been reordered (default False).
-        """
-        return False
-
 
 @functools.lru_cache
 def get_kv_cache_layout():
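With `reorder_batch` removed from the builder interface entirely, the reordering must be driven by whatever owns the builders, keyed off the declared threshold. That call site is not part of this diff, so the following is only a hypothetical sketch of such a consumer: `reorder_batch_to_split_decodes_and_prefills` is real (it was imported by the code removed above), but `run_reordering` and its argument handling are invented for illustration.

```python
# Hypothetical consumer of reorder_batch_threshold; not shown in this diff.
from vllm.v1.attention.backends.utils import (
    reorder_batch_to_split_decodes_and_prefills)


def run_reordering(builder, input_batch, scheduler_output) -> bool:
    """Reorder the batch iff the builder declares a threshold."""
    threshold = builder.reorder_batch_threshold
    if threshold is None:
        return False  # backend opted out of reordering
    return reorder_batch_to_split_decodes_and_prefills(
        input_batch, scheduler_output, decode_threshold=threshold)
```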
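Finally, the "validate ... at init time" of the commit title: hybrid models instantiate several metadata builders (e.g. full attention plus Mamba2), and those builders must agree on how the shared batch is reordered. The diff shown here contains only the per-backend declarations, so the check below is a hypothetical sketch of what such init-time validation could look like, not the commit's actual implementation.

```python
# Hypothetical init-time check; the real call site is not in this excerpt.
from typing import Optional


def validate_reorder_thresholds(builders) -> Optional[int]:
    """Ensure all builders that reorder the batch agree on the threshold."""
    thresholds = {
        b.reorder_batch_threshold
        for b in builders
        if b.reorder_batch_threshold is not None
    }
    if len(thresholds) > 1:
        raise ValueError("Attention backends disagree on batch reordering: "
                         f"found thresholds {sorted(thresholds)}")
    return thresholds.pop() if thresholds else None
```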