[Attention] Refactor attention metadata builder interface (#20466)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -12,13 +12,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
try:
|
||||
@@ -316,19 +315,21 @@ class TorchSDPAMetadata(AttentionMetadata):
|
||||
|
||||
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
|
||||
|
||||
def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable) -> None:
|
||||
self.runner = runner
|
||||
self.block_table = block_table
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device) -> None:
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
|
||||
# For reorder
|
||||
self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
|
||||
dtype=np.int64)
|
||||
self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
|
||||
dtype=np.int64)
|
||||
self.reorder_prompt_req_index_list = np.empty(
|
||||
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
|
||||
self.reorder_decode_req_index_list = np.empty(
|
||||
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
|
||||
self.num_prompt_req: int = 0
|
||||
|
||||
self.seq_start_loc_cpu = torch.zeros(
|
||||
runner.max_num_reqs + 1,
|
||||
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
@@ -378,15 +379,15 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
|
||||
|
||||
return True
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> TorchSDPAMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
runner = self.runner
|
||||
block_table = self.block_table
|
||||
seq_lens_np = runner.seq_lens_np[:num_reqs]
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
seq_lens_np = seq_lens_cpu.numpy()
|
||||
num_prompt_req = self.num_prompt_req
|
||||
max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
|
||||
) if num_prompt_req > 0 else 0
|
||||
@@ -394,34 +395,36 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
|
||||
) if num_prompt_req < num_reqs else 0
|
||||
self.seq_start_loc_np[0] = 0
|
||||
np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
|
||||
num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
|
||||
num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
|
||||
) - num_prefill_tokens
|
||||
slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
|
||||
block_table_tensor = block_table.get_device_tensor()
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
|
||||
num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
|
||||
num_prefill_tokens)
|
||||
|
||||
slot_mapping = common_attn_metadata.slot_mapping.long()
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
|
||||
attn_metadata = TorchSDPAMetadata(
|
||||
num_prefills=num_prompt_req,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
# to ensure inference when chunked_prefill is disabled
|
||||
seq_lens=runner.seq_lens_cpu[:num_reqs].tolist(),
|
||||
seq_lens_tensor=runner.
|
||||
seq_lens_cpu[num_prompt_req:num_reqs], # decode
|
||||
seq_lens=seq_lens_cpu.tolist(),
|
||||
seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs], # decode
|
||||
max_decode_seq_len=max_decode_seq_len, # decode
|
||||
block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode
|
||||
chunked_prefill=self.runner.scheduler_config.
|
||||
chunked_prefill_enabled,
|
||||
chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
|
||||
max_query_len=max_query_len,
|
||||
max_kv_len=max_prefill_seq_len,
|
||||
prefill_query_start_loc=runner.
|
||||
query_start_loc_cpu[:num_prompt_req + 1], # prefill
|
||||
prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
|
||||
1], # prefill
|
||||
kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
|
||||
1], # prefill
|
||||
prefill_block_tables=block_table_tensor[:
|
||||
num_prompt_req], # prefill
|
||||
query_start_loc=runner.query_start_loc_cpu[:num_reqs +
|
||||
1], # for logits index
|
||||
query_start_loc=query_start_loc_cpu[:num_reqs +
|
||||
1], # for logits index
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -29,10 +29,6 @@ from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
|
||||
make_local_attention_virtual_batches)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -162,29 +158,30 @@ class FlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashAttentionMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
model_config = runner.model_config
|
||||
compilation_config = runner.vllm_config.compilation_config
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.device = device
|
||||
|
||||
self.runner = runner
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
runner.parallel_config)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(
|
||||
runner.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
self.num_heads_q = self.model_config.get_num_attention_heads(
|
||||
self.parallel_config)
|
||||
self.num_heads_kv = self.model_config.get_num_kv_heads(
|
||||
self.parallel_config)
|
||||
self.headdim = self.model_config.get_head_size()
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||
self.aot_schedule = (get_flash_attn_version() == 3)
|
||||
self.use_full_cuda_graph = compilation_config.full_cuda_graph
|
||||
self.use_full_cuda_graph = self.compilation_config.full_cuda_graph
|
||||
if self.use_full_cuda_graph:
|
||||
if not self.aot_schedule:
|
||||
raise ValueError(
|
||||
"AoT scheduling is required for full cuda graph.")
|
||||
capture_sizes = compilation_config.cudagraph_capture_sizes
|
||||
capture_sizes = self.compilation_config.cudagraph_capture_sizes
|
||||
if not capture_sizes:
|
||||
raise ValueError(
|
||||
"cudagraph_capture_sizes should not be None when "
|
||||
@@ -198,9 +195,9 @@ class FlashAttentionMetadataBuilder(
|
||||
"full cuda graph.")
|
||||
|
||||
self.scheduler_metadata = torch.zeros(
|
||||
self.runner.max_num_reqs + 1,
|
||||
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device,
|
||||
device=self.device,
|
||||
)
|
||||
# When using cuda graph, we need to set the upper bound of the
|
||||
# number of splits so that large enough intermediate buffers are
|
||||
@@ -211,28 +208,27 @@ class FlashAttentionMetadataBuilder(
|
||||
# populated on first build() call.
|
||||
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
||||
|
||||
def build(
|
||||
self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata
|
||||
) -> FlashAttentionMetadata:
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> FlashAttentionMetadata:
|
||||
"""
|
||||
fast_build disables AOT scheduling, used when there will be few
|
||||
iterations i.e. spec-decode
|
||||
"""
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
|
||||
# mode.
|
||||
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
||||
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
# the overhead of the aot schedule is not worth it for spec-decode
|
||||
aot_schedule = self.aot_schedule and not fast_build
|
||||
|
||||
if self.aot_sliding_window is None:
|
||||
self.aot_sliding_window = (-1, -1)
|
||||
@@ -240,19 +236,20 @@ class FlashAttentionMetadataBuilder(
|
||||
# constant for all layers to. We have to populate this on the first
|
||||
# build() call so the layers are constructed (cannot populate)
|
||||
# in __init__.
|
||||
if self.aot_schedule:
|
||||
if aot_schedule:
|
||||
sliding_window_configs = _get_sliding_window_configs(
|
||||
self.runner.vllm_config)
|
||||
self.vllm_config)
|
||||
if len(sliding_window_configs) == 1:
|
||||
sliding_window_config = sliding_window_configs.pop()
|
||||
if sliding_window_config is not None:
|
||||
self.aot_sliding_window = sliding_window_config
|
||||
elif len(sliding_window_configs) > 1:
|
||||
self.aot_schedule = False
|
||||
aot_schedule = False
|
||||
|
||||
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
|
||||
max_seq_len, causal):
|
||||
if self.aot_schedule:
|
||||
if aot_schedule:
|
||||
return get_scheduler_metadata(
|
||||
batch_size=batch_size,
|
||||
max_seqlen_q=max_query_len,
|
||||
@@ -271,19 +268,19 @@ class FlashAttentionMetadataBuilder(
|
||||
|
||||
# for local attention
|
||||
local_attn_metadata = None
|
||||
if self.runner.attention_chunk_size is not None:
|
||||
if self.model_config.attention_chunk_size is not None:
|
||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
||||
virt_block_table_tensor = make_local_attention_virtual_batches(
|
||||
self.runner.attention_chunk_size,
|
||||
self.runner.query_start_loc_np[:num_reqs + 1],
|
||||
self.runner.seq_lens_np[:num_reqs],
|
||||
self.model_config.attention_chunk_size,
|
||||
query_start_loc_cpu.numpy(),
|
||||
seq_lens_cpu.numpy(),
|
||||
block_table_tensor,
|
||||
self.block_size,
|
||||
)
|
||||
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
self.device, non_blocking=True)
|
||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
self.device, non_blocking=True)
|
||||
local_max_query_len = seqlens_q_local_np.max()
|
||||
local_max_seq_len = virt_k_seqlens_np.max()
|
||||
local_scheduler_metadata = schedule(
|
||||
@@ -308,14 +305,12 @@ class FlashAttentionMetadataBuilder(
|
||||
if use_cascade:
|
||||
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
device=self.device)
|
||||
prefix_kv_lens = torch.tensor([common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
|
||||
common_prefix_len)
|
||||
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
|
||||
self.runner.device)
|
||||
device=self.device)
|
||||
suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
|
||||
self.device, non_blocking=True)
|
||||
prefix_scheduler_metadata = schedule(
|
||||
batch_size=1,
|
||||
cu_query_lens=cu_prefix_query_lens,
|
||||
|
||||
@@ -15,22 +15,20 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
PerLayerParameters,
|
||||
get_kv_cache_layout,
|
||||
get_per_layer_parameters,
|
||||
infer_global_hyperparameters)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
|
||||
get_kv_cache_layout, get_per_layer_parameters,
|
||||
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
|
||||
@@ -226,9 +224,9 @@ class FlashInferMetadata:
|
||||
|
||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
self.runner = runner
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device):
|
||||
self.device = device
|
||||
self._workspace_buffer = None
|
||||
self._prefill_wrapper = None # Wrapper for prefill/append
|
||||
self._decode_wrapper = None # Wrapper for decode
|
||||
@@ -237,75 +235,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# Global hyperparameters shared by all attention layers
|
||||
self.global_hyperparameters: Optional[PerLayerParameters] = None
|
||||
|
||||
self.vllm_config = runner.vllm_config
|
||||
self.vllm_config = vllm_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
def reorder_batch(self, input_batch: InputBatch,
|
||||
scheduler_output: SchedulerOutput) -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are and
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the decode run only supports num_tokens = 1
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=1)
|
||||
|
||||
def _get_workspace_buffer(self):
|
||||
if self._workspace_buffer is None:
|
||||
self._workspace_buffer = torch.empty(
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device=self.runner.device)
|
||||
device=self.device)
|
||||
return self._workspace_buffer
|
||||
|
||||
def _get_prefill_wrapper(self):
|
||||
@@ -316,10 +261,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
def _get_decode_wrapper(self):
|
||||
if self._decode_wrapper is None:
|
||||
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config))
|
||||
num_kv_heads = self.runner.model_config.get_num_kv_heads(
|
||||
self.runner.parallel_config)
|
||||
num_qo_heads = (
|
||||
self.vllm_config.model_config.get_num_attention_heads(
|
||||
self.vllm_config.parallel_config))
|
||||
num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
|
||||
self.vllm_config.parallel_config)
|
||||
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
|
||||
num_qo_heads // num_kv_heads > 4)
|
||||
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
@@ -334,7 +280,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
2, self._get_workspace_buffer(), get_kv_cache_layout())
|
||||
return self._cascade_wrapper
|
||||
|
||||
def _plan(self, attn_metadata: FlashInferMetadata):
|
||||
def _plan(self, num_prefills: int, num_decodes: int,
|
||||
attn_metadata: FlashInferMetadata):
|
||||
if self.global_hyperparameters is None:
|
||||
self.global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(self.vllm_config, FlashInferImpl))
|
||||
@@ -369,16 +316,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# Regular attention (common case).
|
||||
# Decodes are at the front and prefills are at the back,
|
||||
# according to reorder_batch()
|
||||
if self._num_prefills > 0:
|
||||
if num_prefills > 0:
|
||||
# Decodes are first so prefills start after the last decode
|
||||
prefill_start = self._num_decodes
|
||||
prefill_start = num_decodes
|
||||
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
|
||||
assert attn_metadata.qo_indptr[prefill_start:].shape[
|
||||
0] == self._num_prefills + 1
|
||||
0] == num_prefills + 1
|
||||
assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
|
||||
0] == self._num_prefills + 1
|
||||
0] == num_prefills + 1
|
||||
assert attn_metadata.paged_kv_last_page_len[
|
||||
prefill_start:].shape[0] == self._num_prefills
|
||||
prefill_start:].shape[0] == num_prefills
|
||||
# Since prefill_wrapper.run() will be called with
|
||||
# query[num_decode_tokens:] we need to adjust the qo_indptr
|
||||
# to be relative to the start of the prefill queries.
|
||||
@@ -402,17 +349,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
kv_data_type=attn_metadata.kv_data_type,
|
||||
)
|
||||
|
||||
if self._num_decodes > 0:
|
||||
if num_decodes > 0:
|
||||
attn_metadata.decode_wrapper = self._get_decode_wrapper()
|
||||
if not FlashInferBackend.use_trtllm_decode_attention(
|
||||
self._num_decodes, attn_metadata.max_seq_len,
|
||||
num_decodes, attn_metadata.max_seq_len,
|
||||
attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads, attn_metadata.head_dim):
|
||||
attn_metadata.decode_wrapper.plan(
|
||||
attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
|
||||
attn_metadata.paged_kv_indptr[:num_decodes + 1],
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_len[:self.
|
||||
_num_decodes],
|
||||
attn_metadata.paged_kv_last_page_len[:num_decodes],
|
||||
attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim,
|
||||
@@ -427,22 +373,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
kv_data_type=attn_metadata.kv_data_type,
|
||||
)
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> FlashInferMetadata:
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
|
||||
split_decodes_and_prefills(common_attn_metadata)
|
||||
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
assert (self._num_decode_tokens +
|
||||
self._num_prefill_tokens == num_actual_tokens)
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
device = self.runner.device
|
||||
device = self.device
|
||||
qo_indptr = common_attn_metadata.query_start_loc
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
max_seq_len = common_attn_metadata.seq_lens_cpu.max()
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
|
||||
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True).long()
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
|
||||
block_table_bounds = (seq_lens + page_size - 1) // page_size
|
||||
|
||||
@@ -487,7 +431,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
paged_kv_last_page_len = seq_lens % page_size
|
||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
|
||||
page_size, paged_kv_last_page_len)
|
||||
cache_dtype = self.runner.cache_config.cache_dtype
|
||||
cache_dtype = self.cache_config.cache_dtype
|
||||
if cache_dtype.startswith("fp8"):
|
||||
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
cache_dtype)
|
||||
@@ -499,17 +443,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
num_qo_heads=self.runner.num_query_heads,
|
||||
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
|
||||
self.vllm_config.parallel_config),
|
||||
num_kv_heads=self.kv_cache_spec.num_kv_heads,
|
||||
head_dim=self.kv_cache_spec.head_size,
|
||||
page_size=page_size,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
q_data_type=self.runner.dtype,
|
||||
slot_mapping=slot_mapping,
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
num_prefills=self._num_prefills,
|
||||
num_prefill_tokens=self._num_prefill_tokens,
|
||||
q_data_type=self.vllm_config.model_config.dtype,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
use_cascade=use_cascade,
|
||||
shared_qo_indptr=shared_qo_indptr,
|
||||
shared_kv_page_indptr=shared_kv_page_indptr,
|
||||
@@ -521,12 +466,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
)
|
||||
|
||||
self._plan(attn_metadata)
|
||||
self._plan(num_prefills, num_decodes, attn_metadata)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
|
||||
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
|
||||
# TODO: The cascade wrapper currently does not support setting
|
||||
# kv cache dtype to something different from query dtype.
|
||||
return False
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"""Attention layer with FlashAttention."""
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
|
||||
@@ -14,18 +14,15 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
create_block_mask_compiled = torch.compile(create_block_mask,
|
||||
fullgraph=True,
|
||||
mode="reduce-overhead")
|
||||
@@ -261,36 +258,34 @@ class FlexAttentionMetadata:
|
||||
class FlexAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlexAttentionMetadata]):
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
model_config = runner.model_config
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device):
|
||||
self.model_config = vllm_config.model_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
|
||||
self.runner = runner
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
runner.parallel_config)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(
|
||||
runner.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
self.num_heads_q = self.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config)
|
||||
self.num_heads_kv = self.model_config.get_num_kv_heads(
|
||||
vllm_config.parallel_config)
|
||||
self.headdim = self.model_config.get_head_size()
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
self.device = device
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> FlexAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
cu_prefix_query_lens = None
|
||||
@@ -300,17 +295,15 @@ class FlexAttentionMetadataBuilder(
|
||||
raise NotImplementedError("Not yet my friend")
|
||||
|
||||
block_size = self.kv_cache_spec.block_size
|
||||
max_possible_seq_len = self.runner.model_config.max_model_len
|
||||
total_cache_tokens = (self.runner.cache_config.num_gpu_blocks *
|
||||
block_size)
|
||||
max_possible_seq_len = self.model_config.max_model_len
|
||||
total_cache_tokens = self.cache_config.num_gpu_blocks * block_size
|
||||
|
||||
inverse_block_table = physical_to_logical_mapping(
|
||||
block_table_tensor, self.runner.cache_config.num_gpu_blocks)
|
||||
block_table_tensor, self.cache_config.num_gpu_blocks)
|
||||
|
||||
# Get the original offset tensor
|
||||
offset_tensor = torch.tensor(
|
||||
self.runner.input_batch.num_computed_tokens_cpu[:num_reqs]).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device, non_blocking=True)
|
||||
|
||||
out = FlexAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
|
||||
@@ -7,15 +7,15 @@ from typing import TYPE_CHECKING, Optional
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import MambaSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
||||
@@ -87,80 +87,24 @@ class Mamba2AttentionMetadata:
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
|
||||
block_table: BlockTable):
|
||||
self.runner = runner
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device):
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size(
|
||||
)
|
||||
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
|
||||
assert self.chunk_size is not None, (
|
||||
"chunk_size needs to be set in the model config for Mamba2 models")
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# NOTE (Chen): Copied from MLACommonMetadataBuilder and
|
||||
# FlashInferMetadataBuilder. Should be refactored later to avoid code
|
||||
# duplication of these 3 functions.
|
||||
# We now want to reorder the batch so that the "decode" requests are and
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=1)
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the decode run only supports num_tokens = 1
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> Mamba2AttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
@@ -172,29 +116,31 @@ class Mamba2AttentionMetadataBuilder(
|
||||
has_initial_states = None
|
||||
prep_initial_states = False
|
||||
|
||||
state_indices_tensor = self.block_table.block_table[:num_reqs, 0]
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(common_attn_metadata,
|
||||
decode_threshold=1))
|
||||
|
||||
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
|
||||
if self._num_prefills > 0:
|
||||
if num_prefills > 0:
|
||||
#[batch,]
|
||||
has_initial_states_cpu = (
|
||||
self.runner.input_batch.
|
||||
num_computed_tokens_cpu_tensor[num_reqs -
|
||||
self._num_prefills:num_reqs]
|
||||
> 0)
|
||||
common_attn_metadata.
|
||||
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
|
||||
prep_initial_states = torch.any(has_initial_states_cpu).item()
|
||||
has_initial_states = has_initial_states_cpu.to(
|
||||
query_start_loc.device)
|
||||
|
||||
query_start_loc_p = common_attn_metadata.query_start_loc[
|
||||
-self._num_prefills - 1:] - self._num_decode_tokens
|
||||
-num_prefills - 1:] - num_decode_tokens
|
||||
|
||||
seq_idx = torch.repeat_interleave(
|
||||
torch.arange(self._num_prefills,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc_p.device),
|
||||
query_start_loc_p.diff(),
|
||||
output_size=self._num_prefill_tokens)
|
||||
seq_idx = torch.repeat_interleave(torch.arange(
|
||||
num_prefills,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc_p.device),
|
||||
query_start_loc_p.diff(),
|
||||
output_size=num_prefill_tokens)
|
||||
seq_idx.unsqueeze_(0)
|
||||
|
||||
# We compute metadata for chunked prefill once at the top level
|
||||
@@ -204,13 +150,13 @@ class Mamba2AttentionMetadataBuilder(
|
||||
chunk_indices, chunk_offsets = (
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
query_start_loc_p, self.chunk_size,
|
||||
self._num_prefill_tokens))
|
||||
num_prefill_tokens))
|
||||
|
||||
attn_metadata = Mamba2AttentionMetadata(
|
||||
num_prefills=self._num_prefills,
|
||||
num_prefill_tokens=self._num_prefill_tokens,
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
has_initial_states=has_initial_states,
|
||||
|
||||
@@ -202,18 +202,18 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
get_per_layer_parameters,
|
||||
infer_global_hyperparameters)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
get_per_layer_parameters, infer_global_hyperparameters,
|
||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
@@ -235,7 +235,6 @@ except ImportError:
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -406,22 +405,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runner: "GPUModelRunner",
|
||||
kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: Optional[type[M]] = None):
|
||||
self.metadata_cls = metadata_cls \
|
||||
if metadata_cls is not None else MLACommonMetadata
|
||||
self.runner = runner
|
||||
scheduler_config = runner.scheduler_config
|
||||
model_config = runner.model_config
|
||||
cache_config = runner.cache_config
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
self.num_heads = model_config.get_num_attention_heads(
|
||||
runner.parallel_config)
|
||||
self.mla_dims = get_mla_dims(model_config)
|
||||
self.aot_schedule = current_platform.is_cuda()
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.device = device
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
self.num_heads = self.model_config.get_num_attention_heads(
|
||||
parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.aot_schedule = current_platform.is_cuda()
|
||||
|
||||
# Dont try to access the runner on AMD
|
||||
if self.aot_schedule:
|
||||
@@ -432,7 +432,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
# Max sure there is enough for 8 full length request or at least
|
||||
# 4 pages of cache per request
|
||||
max(
|
||||
8 * model_config.max_model_len, 4 *
|
||||
8 * self.model_config.max_model_len, 4 *
|
||||
scheduler_config.max_num_seqs * cache_config.block_size),
|
||||
# For long-context models try not to over-allocate limiting
|
||||
# kv-cache space, limiting it to 64k tokens,
|
||||
@@ -447,13 +447,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
scheduler_config.max_num_seqs * cache_config.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
model_config.get_head_size()),
|
||||
dtype=model_config.dtype,
|
||||
device=runner.device,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.block_table = block_table
|
||||
|
||||
self._use_cudnn_prefill = use_cudnn_prefill()
|
||||
self._use_fi_prefill = use_flashinfer_prefill()
|
||||
self.prefill_metadata_cls = (
|
||||
@@ -465,7 +463,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
self._workspace_buffer = torch.empty(
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device=runner.device)
|
||||
device=device)
|
||||
|
||||
self._fi_prefill_main: Optional[
|
||||
BatchPrefillWithRaggedKVCacheWrapper] = None
|
||||
@@ -473,13 +471,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
BatchPrefillWithRaggedKVCacheWrapper] = []
|
||||
|
||||
self._global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(runner.vllm_config, MLACommonImpl))
|
||||
get_per_layer_parameters(vllm_config, MLACommonImpl))
|
||||
|
||||
if self._use_cudnn_prefill:
|
||||
self.cudnn_workspace = torch.empty(
|
||||
CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
|
||||
dtype=torch.int8,
|
||||
device=runner.device,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
|
||||
@@ -505,7 +503,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
assert num_chunks <= len(self._fi_prefill_chunks)
|
||||
|
||||
# In MLA, the non-latent num_qo_heads == num_kv_heads
|
||||
num_qo_heads = self.runner.num_query_heads
|
||||
num_qo_heads = self.num_heads
|
||||
num_kv_heads = num_qo_heads
|
||||
|
||||
# Sanity: Verify that num_kv_heads == 1 since it is latent space
|
||||
@@ -531,7 +529,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
sm_scale=self._global_hyperparameters.sm_scale,
|
||||
window_left=self._global_hyperparameters.window_left,
|
||||
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
||||
q_data_type=self.runner.dtype,
|
||||
q_data_type=self.model_config.dtype,
|
||||
kv_data_type=self.kv_cache_spec.dtype,
|
||||
)
|
||||
|
||||
@@ -552,7 +550,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
window_left=self._global_hyperparameters.window_left,
|
||||
logits_soft_cap=self._global_hyperparameters.
|
||||
logits_soft_cap,
|
||||
q_data_type=self.runner.dtype,
|
||||
q_data_type=self.model_config.dtype,
|
||||
kv_data_type=self.kv_cache_spec.dtype,
|
||||
)
|
||||
|
||||
@@ -561,63 +559,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are and
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the TritonMLA._forward_decode only supports
|
||||
# num_tokens = 1
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=1)
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor):
|
||||
@@ -639,49 +583,50 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
m.max_query_len = 1 # decode-only
|
||||
|
||||
# Update state usually set in reorder_batch.
|
||||
self._num_decodes = m.num_reqs
|
||||
self._num_decode_tokens = m.num_actual_tokens
|
||||
self._num_prefills = 0
|
||||
self._num_prefill_tokens = 0
|
||||
return self.build(0, m)
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> M:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||
# it blocks on all previous kernels.
|
||||
device = self.runner.device
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
device = self.device
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
prefill_metadata = None
|
||||
if self._num_prefills > 0:
|
||||
reqs_start = self._num_decodes # prefill_start
|
||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
context_lens_cpu = self.runner.input_batch.\
|
||||
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
|
||||
num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
|
||||
query_seq_lens_cpu)
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata)
|
||||
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
reqs_start = num_decodes # prefill_start
|
||||
|
||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
chunked_context_metadata = None
|
||||
if self.chunked_prefill_enabled and self._num_prefills > 0 \
|
||||
if self.chunked_prefill_enabled and num_prefills > 0 \
|
||||
and max_context_len_cpu > 0:
|
||||
# NOTE: it is recommend you read the `Chunked Prefill` section
|
||||
# in the comment at the top of the file before trying to
|
||||
@@ -712,14 +657,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
# of `to_list`.
|
||||
chunk_starts = \
|
||||
torch.arange(num_chunks, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, self._num_prefills) \
|
||||
.unsqueeze(1).expand(-1, num_prefills) \
|
||||
* max_context_chunk
|
||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||
chunk_starts + max_context_chunk)
|
||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||
|
||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||
self._num_prefills + 1,
|
||||
num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(chunk_seq_lens,
|
||||
@@ -762,28 +707,28 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
prefill_metadata.cudnn_workspace = self.cudnn_workspace
|
||||
|
||||
decode_metadata = None
|
||||
if self._num_decodes > 0:
|
||||
if num_decodes > 0:
|
||||
decode_metadata = self._build_decode(
|
||||
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
|
||||
seq_lens=seq_lens[:self._num_decodes],
|
||||
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
||||
seq_lens=seq_lens[:num_decodes],
|
||||
)
|
||||
|
||||
attn_metadata = self.metadata_cls(
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
num_actual_tokens=num_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
# MLACommonMetadata Chunk prefill specific
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
num_prefills=self._num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
if self._use_fi_prefill and self._num_prefills > 0:
|
||||
if self._use_fi_prefill and num_prefills > 0:
|
||||
assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
|
||||
self._build_fi_prefill_wrappers(attn_metadata.prefill)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionType,
|
||||
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
@@ -18,7 +19,6 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -56,12 +56,13 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
|
||||
|
||||
def __init__(self, runner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device):
|
||||
super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata)
|
||||
|
||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config)
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config)
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
@@ -75,7 +76,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
if self.runner.full_cuda_graph:
|
||||
if self.compilation_config.full_cuda_graph:
|
||||
# First time around (CUDAGraph capture), allocate the static buffer
|
||||
if self.cg_buf_tile_scheduler_metadata is None:
|
||||
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
|
||||
|
||||
@@ -8,6 +8,8 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import cdiv
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
@@ -16,7 +18,6 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
# yapf: enable
|
||||
|
||||
@@ -65,24 +66,26 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = True # decode only
|
||||
|
||||
def __init__(self, runner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device):
|
||||
super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata)
|
||||
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
|
||||
"only supports block size 1."
|
||||
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
|
||||
self.kv_cache_spec.block_size)
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
|
||||
# Preparing persistent buffers
|
||||
if self.runner.full_cuda_graph:
|
||||
device = self.runner.device
|
||||
max_num_reqs = self.runner.max_num_reqs
|
||||
if vllm_config.compilation_config.full_cuda_graph:
|
||||
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.paged_kv_indices = torch.zeros(
|
||||
block_table.get_device_tensor().numel(
|
||||
), # max num pages possible
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.paged_kv_indices = torch.zeros(max_num_pages,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
@@ -96,7 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
block_table_bounds = (seq_lens + page_size - 1) // page_size
|
||||
device = self.runner.device
|
||||
device = self.device
|
||||
num_reqs = seq_lens.size(0)
|
||||
|
||||
mask = (torch.arange(block_table_tensor.size(1),
|
||||
dtype=block_table_tensor.dtype,
|
||||
@@ -113,8 +117,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
|
||||
])
|
||||
|
||||
if self.runner.full_cuda_graph:
|
||||
num_reqs = self._num_decodes
|
||||
if self.compilation_config.full_cuda_graph:
|
||||
|
||||
num_actual_pages = paged_kv_indices.size(0)
|
||||
|
||||
@@ -137,7 +140,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
|
||||
else:
|
||||
qo_indptr = torch.arange(0,
|
||||
self._num_decodes + 1,
|
||||
num_reqs + 1,
|
||||
step=1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with AiterFlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,18 +10,13 @@ from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import (
|
||||
make_local_attention_virtual_batches)
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if current_platform.is_rocm():
|
||||
import aiter
|
||||
@@ -172,54 +167,49 @@ logger = init_logger(__name__)
|
||||
|
||||
class AiterFlashAttentionMetadataBuilder:
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
model_config = runner.model_config
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.device = device
|
||||
|
||||
self.runner = runner
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
runner.parallel_config)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(
|
||||
runner.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
self.num_heads_q = self.model_config.get_num_attention_heads(
|
||||
self.parallel_config)
|
||||
self.num_heads_kv = self.model_config.get_num_kv_heads(
|
||||
self.parallel_config)
|
||||
self.headdim = self.model_config.get_head_size()
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
# Sliding window size to be used with the AOT scheduler will be
|
||||
# populated on first build() call.
|
||||
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
def reorder_batch(self, input_batch, scheduler_output) -> bool:
|
||||
return False
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
|
||||
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
|
||||
# mode.
|
||||
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
||||
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
device=self.device)
|
||||
torch.cumsum(seq_lens,
|
||||
dim=0,
|
||||
dtype=cu_seq_lens.dtype,
|
||||
@@ -231,21 +221,21 @@ class AiterFlashAttentionMetadataBuilder:
|
||||
|
||||
# for local attention
|
||||
local_attn_metadata = None
|
||||
if self.runner.attention_chunk_size is not None:
|
||||
if self.model_config.attention_chunk_size is not None:
|
||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
||||
virt_block_table_tensor = make_local_attention_virtual_batches(
|
||||
self.runner.attention_chunk_size,
|
||||
self.runner.query_start_loc_np[:num_reqs + 1],
|
||||
self.runner.seq_lens_np[:num_reqs],
|
||||
self.model_config.attention_chunk_size,
|
||||
query_start_loc_cpu.numpy(),
|
||||
seq_lens_cpu.numpy(),
|
||||
block_table_tensor,
|
||||
self.block_size,
|
||||
)
|
||||
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
self.device, non_blocking=True)
|
||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_max_query_len = int(seqlens_q_local_np.max())
|
||||
local_max_seq_len = int(virt_k_seqlens_np.max())
|
||||
self.device, non_blocking=True)
|
||||
local_max_query_len = seqlens_q_local_np.max().item()
|
||||
local_max_seq_len = virt_k_seqlens_np.max().item()
|
||||
local_scheduler_metadata = schedule(
|
||||
batch_size=local_query_start_loc.shape[0] - 1,
|
||||
cu_query_lens=local_query_start_loc,
|
||||
@@ -256,12 +246,11 @@ class AiterFlashAttentionMetadataBuilder:
|
||||
|
||||
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
device=self.device)
|
||||
local_cu_seq_lens[1:] = torch.cumsum(
|
||||
torch.from_numpy(virt_k_seqlens_np).to(
|
||||
device=self.runner.device,
|
||||
dtype=torch.int32,
|
||||
non_blocking=True),
|
||||
torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
|
||||
dtype=torch.int32,
|
||||
non_blocking=True),
|
||||
dim=0)
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.attention.ops.chunked_prefill_paged_decode import (
|
||||
chunked_prefill_paged_decode)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
@@ -21,10 +22,6 @@ from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
make_local_attention_virtual_batches)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -75,12 +72,21 @@ class TritonAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = True
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
self.runner = runner
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device):
|
||||
self.device = device
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(
|
||||
vllm_config.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
|
||||
self.attention_chunk_size = getattr(vllm_config.scheduler_config,
|
||||
'attention_chunk_size', None)
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
@@ -92,46 +98,36 @@ class TritonAttentionMetadataBuilder(
|
||||
attn_metadata.seq_lens.fill_(1)
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata
|
||||
) -> TritonAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> TritonAttentionMetadata:
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
|
||||
# mode.
|
||||
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
||||
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
# for local attention
|
||||
local_attn_metadata = None
|
||||
if self.runner.attention_chunk_size is not None:
|
||||
if self.attention_chunk_size is not None:
|
||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
||||
virt_block_table_tensor = make_local_attention_virtual_batches(
|
||||
self.runner.attention_chunk_size,
|
||||
self.runner.query_start_loc_np[:num_reqs + 1],
|
||||
self.runner.seq_lens_np[:num_reqs],
|
||||
self.attention_chunk_size,
|
||||
common_attn_metadata.query_start_loc_cpu.numpy(),
|
||||
common_attn_metadata.seq_lens_cpu.numpy(),
|
||||
block_table_tensor,
|
||||
self.block_size,
|
||||
)
|
||||
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
self.device, non_blocking=True)
|
||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_max_query_len = seqlens_q_local_np.max()
|
||||
local_max_seq_len = virt_k_seqlens_np.max()
|
||||
self.device, non_blocking=True)
|
||||
local_max_query_len = seqlens_q_local_np.max().item()
|
||||
local_max_seq_len = virt_k_seqlens_np.max().item()
|
||||
|
||||
local_attn_metadata = TritonAttentionMetadata \
|
||||
.LocalAttentionMetadata(
|
||||
@@ -148,14 +144,13 @@ class TritonAttentionMetadataBuilder(
|
||||
if use_cascade:
|
||||
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
device=self.device)
|
||||
prefix_kv_lens = torch.tensor([common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
|
||||
device=self.device)
|
||||
suffix_kv_lens = (common_attn_metadata.seq_lens_cpu -
|
||||
common_prefix_len)
|
||||
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
|
||||
self.runner.device)
|
||||
suffix_kv_lens = suffix_kv_lens.to(self.device)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
|
||||
@@ -22,6 +22,7 @@ import vllm.envs as envs
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
get_kv_connector_cache_layout)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
_KV_CACHE_LAYOUT_OVERRIDE = None
|
||||
@@ -32,14 +33,22 @@ class CommonAttentionMetadata:
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
For many of the tensors we keep both GPU and CPU versions.
|
||||
"""
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_cpu: torch.Tensor
|
||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
seq_lens_cpu: torch.Tensor
|
||||
"""(batch_size,), the length of each request including both computed tokens
|
||||
and newly scheduled tokens"""
|
||||
|
||||
num_computed_tokens_cpu: torch.Tensor
|
||||
"""(batch_size,), the number of computed tokens for each request"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
num_actual_tokens: int
|
||||
@@ -47,6 +56,14 @@ class CommonAttentionMetadata:
|
||||
max_query_len: int
|
||||
"""Longest query in batch"""
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
def __post_init__(self):
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
|
||||
# mode.
|
||||
self.slot_mapping[self.num_actual_tokens:].fill_(-1)
|
||||
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
@@ -56,11 +73,25 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
full_cudagraph_supported: ClassVar[bool] = False
|
||||
|
||||
@abstractmethod
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device):
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
|
||||
@abstractmethod
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> M:
|
||||
"""
|
||||
Central method that builds attention metadata.
|
||||
Some builders (MLA) require reorder_batch to be called prior to build.
|
||||
|
||||
Args:
|
||||
common_prefix_len: The length of the common prefix of the batch.
|
||||
common_attn_metadata: The common attention metadata.
|
||||
fast_build: The meta-data will prioritize speed of building over
|
||||
then speed at execution. Can be used for spec-decode where the
|
||||
result of a build call may only be used for few layers/iters.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -351,3 +382,108 @@ def make_local_attention_virtual_batches(
|
||||
|
||||
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
|
||||
block_table_local
|
||||
|
||||
|
||||
def split_decodes_and_prefills(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
decode_threshold: int = 1,
|
||||
) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
Assuming a reordered batch, finds the boundary between prefill and decode
|
||||
requests.
|
||||
|
||||
Args:
|
||||
common_attn_metadata: CommonAttentionMetadata object containing the
|
||||
batch metadata.
|
||||
decode_threshold: The maximum query length to be considered a decode.
|
||||
|
||||
Returns:
|
||||
num_decodes: The number of decode requests.
|
||||
num_prefills: The number of prefill requests.
|
||||
num_decode_tokens: The number of tokens in the decode requests.
|
||||
num_prefill_tokens: The number of tokens in the prefill requests.
|
||||
"""
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
if max_query_len <= decode_threshold:
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
is_prefill = query_lens > decode_threshold
|
||||
if not torch.any(is_prefill):
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
||||
assert torch.all(query_lens[first_prefill:] > decode_threshold)
|
||||
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
|
||||
num_decodes = first_prefill
|
||||
num_prefills = num_reqs - num_decodes
|
||||
num_decode_tokens = query_start_loc[first_prefill].item()
|
||||
num_prefill_tokens = num_tokens - num_decode_tokens
|
||||
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
|
||||
|
||||
|
||||
def reorder_batch_to_split_decodes_and_prefills(
|
||||
input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput",
|
||||
decode_threshold: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Reorders the batch to split into prefill and decode requests; places all
|
||||
requests with <= decode_threshold tokens at the front of the batch.
|
||||
|
||||
Returns:
|
||||
True if the batch was modified, False otherwise.
|
||||
"""
|
||||
# We now want to reorder the batch so that the "decode" requests are at
|
||||
# the front and the "prefill" requests are at the back using the least
|
||||
# amount of swaps possible. (NOTE for now we loosely use "decode" to mean
|
||||
# requests where attention is likely memory-bound and "prefill" to mean
|
||||
# requests where attention is likely compute-bound, TODO(lucas): figure out
|
||||
# a better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the TritonMLA._forward_decode only supports
|
||||
# num_tokens = 1
|
||||
if num_tokens <= decode_threshold:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
return modified_batch
|
||||
|
||||
Reference in New Issue
Block a user