[Attention] Refactor attention metadata builder interface (#20466)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson
2025-07-17 00:44:25 -04:00
committed by GitHub
parent 28a6d5423d
commit 76b494444f
18 changed files with 1441 additions and 772 deletions


@@ -12,13 +12,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch
try:
@@ -316,19 +315,21 @@ class TorchSDPAMetadata(AttentionMetadata):
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable) -> None:
self.runner = runner
self.block_table = block_table
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device) -> None:
self.kv_cache_spec = kv_cache_spec
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
# For reorder
self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
dtype=np.int64)
self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
dtype=np.int64)
self.reorder_prompt_req_index_list = np.empty(
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
self.reorder_decode_req_index_list = np.empty(
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
self.num_prompt_req: int = 0
self.seq_start_loc_cpu = torch.zeros(
runner.max_num_reqs + 1,
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device="cpu",
)
@@ -378,15 +379,15 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
return True
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> TorchSDPAMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
runner = self.runner
block_table = self.block_table
seq_lens_np = runner.seq_lens_np[:num_reqs]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens_np = seq_lens_cpu.numpy()
num_prompt_req = self.num_prompt_req
max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
) if num_prompt_req > 0 else 0
@@ -394,34 +395,36 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
) if num_prompt_req < num_reqs else 0
self.seq_start_loc_np[0] = 0
np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
) - num_prefill_tokens
slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
block_table_tensor = block_table.get_device_tensor()
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
num_prefill_tokens)
slot_mapping = common_attn_metadata.slot_mapping.long()
block_table_tensor = common_attn_metadata.block_table_tensor
attn_metadata = TorchSDPAMetadata(
num_prefills=num_prompt_req,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
# to ensure inference when chunked_prefill is disabled
seq_lens=runner.seq_lens_cpu[:num_reqs].tolist(),
seq_lens_tensor=runner.
seq_lens_cpu[num_prompt_req:num_reqs], # decode
seq_lens=seq_lens_cpu.tolist(),
seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs], # decode
max_decode_seq_len=max_decode_seq_len, # decode
block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode
chunked_prefill=self.runner.scheduler_config.
chunked_prefill_enabled,
chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
max_query_len=max_query_len,
max_kv_len=max_prefill_seq_len,
prefill_query_start_loc=runner.
query_start_loc_cpu[:num_prompt_req + 1], # prefill
prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
1], # prefill
kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
1], # prefill
prefill_block_tables=block_table_tensor[:
num_prompt_req], # prefill
query_start_loc=runner.query_start_loc_cpu[:num_reqs +
1], # for logits index
query_start_loc=query_start_loc_cpu[:num_reqs +
1], # for logits index
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
)

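Taken together, the hunks in this commit move every builder to the same constructor and build() signature and make build() read batch state from CommonAttentionMetadata instead of reaching into the model runner. Below is a minimal sketch of that contract; the field list and signatures are inferred from the hunks in this commit, not copied from the authoritative definitions in vllm/v1/attention/backends/utils.py, which may carry additional fields.

# Sketch only: inferred from this diff, not the real module.
from dataclasses import dataclass
from typing import Generic, TypeVar

import torch

M = TypeVar("M")


@dataclass
class CommonAttentionMetadataSketch:
    query_start_loc: torch.Tensor      # (num_reqs + 1,), on device
    query_start_loc_cpu: torch.Tensor  # (num_reqs + 1,), on CPU
    seq_lens: torch.Tensor             # (num_reqs,), on device
    seq_lens_cpu: torch.Tensor         # (num_reqs,), on CPU
    num_computed_tokens_cpu: torch.Tensor
    block_table_tensor: torch.Tensor   # (num_reqs, max_blocks_per_req)
    slot_mapping: torch.Tensor         # (num_actual_tokens,)
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int


class AttentionMetadataBuilderSketch(Generic[M]):
    """Shape of the refactored builder: no runner, no per-builder BlockTable."""

    def __init__(self, kv_cache_spec, vllm_config, device: torch.device) -> None:
        self.kv_cache_spec = kv_cache_spec
        self.vllm_config = vllm_config
        self.device = device

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadataSketch,
              fast_build: bool = False) -> M:
        raise NotImplementedError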

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from typing import Any, ClassVar, Optional
import numpy as np
import torch
@@ -29,10 +29,6 @@ from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
make_local_attention_virtual_batches)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
@@ -162,29 +158,30 @@ class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
model_config = runner.model_config
compilation_config = runner.vllm_config.compilation_config
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.device = device
self.runner = runner
self.num_heads_q = model_config.get_num_attention_heads(
runner.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)
self.num_heads_kv = self.model_config.get_num_kv_heads(
self.parallel_config)
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = compilation_config.full_cuda_graph
self.use_full_cuda_graph = self.compilation_config.full_cuda_graph
if self.use_full_cuda_graph:
if not self.aot_schedule:
raise ValueError(
"AoT scheduling is required for full cuda graph.")
capture_sizes = compilation_config.cudagraph_capture_sizes
capture_sizes = self.compilation_config.cudagraph_capture_sizes
if not capture_sizes:
raise ValueError(
"cudagraph_capture_sizes should not be None when "
@@ -198,9 +195,9 @@ class FlashAttentionMetadataBuilder(
"full cuda graph.")
self.scheduler_metadata = torch.zeros(
self.runner.max_num_reqs + 1,
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device=self.runner.device,
device=self.device,
)
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
@@ -211,28 +208,27 @@ class FlashAttentionMetadataBuilder(
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None
def build(
self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata
) -> FlashAttentionMetadata:
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashAttentionMetadata:
"""
fast_build disables AOT scheduling; used when there will be only a few
iterations, e.g. spec-decode.
"""
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
# the overhead of the aot schedule is not worth it for spec-decode
aot_schedule = self.aot_schedule and not fast_build
if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1)
@@ -240,19 +236,20 @@ class FlashAttentionMetadataBuilder(
# constant for all layers. We have to populate this on the first
# build() call so the layers are constructed (we cannot populate it
# in __init__).
if self.aot_schedule:
if aot_schedule:
sliding_window_configs = _get_sliding_window_configs(
self.runner.vllm_config)
self.vllm_config)
if len(sliding_window_configs) == 1:
sliding_window_config = sliding_window_configs.pop()
if sliding_window_config is not None:
self.aot_sliding_window = sliding_window_config
elif len(sliding_window_configs) > 1:
self.aot_schedule = False
aot_schedule = False
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
if self.aot_schedule:
if aot_schedule:
return get_scheduler_metadata(
batch_size=batch_size,
max_seqlen_q=max_query_len,
@@ -271,19 +268,19 @@ class FlashAttentionMetadataBuilder(
# for local attention
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
if self.model_config.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.runner.attention_chunk_size,
self.runner.query_start_loc_np[:num_reqs + 1],
self.runner.seq_lens_np[:num_reqs],
self.model_config.attention_chunk_size,
query_start_loc_cpu.numpy(),
seq_lens_cpu.numpy(),
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.runner.device, non_blocking=True)
self.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True)
self.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max()
local_max_seq_len = virt_k_seqlens_np.max()
local_scheduler_metadata = schedule(
@@ -308,14 +305,12 @@ class FlashAttentionMetadataBuilder(
if use_cascade:
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device=self.runner.device)
device=self.device)
prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32,
device=self.runner.device)
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
common_prefix_len)
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
self.runner.device)
device=self.device)
suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
self.device, non_blocking=True)
prefix_scheduler_metadata = schedule(
batch_size=1,
cu_query_lens=cu_prefix_query_lens,

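The first-build logic above resolves a single sliding-window config for AOT scheduling and disables AOT scheduling when layers disagree. A standalone sketch of that resolution, assuming the per-layer configs were already collected into a set (as _get_sliding_window_configs does in the hunk); the function name here is hypothetical.

from typing import Optional


def resolve_aot_sliding_window(
        sliding_window_configs: set[Optional[tuple[int, int]]],
        aot_schedule: bool) -> tuple[tuple[int, int], bool]:
    window = (-1, -1)  # sentinel meaning "no sliding window"
    if len(sliding_window_configs) == 1:
        config = next(iter(sliding_window_configs))
        if config is not None:
            window = config
    elif len(sliding_window_configs) > 1:
        # Mixed per-layer windows: AOT scheduling cannot assume one value.
        aot_schedule = False
    return window, aot_schedule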

@@ -15,22 +15,20 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
PerLayerParameters,
get_kv_cache_layout,
get_per_layer_parameters,
infer_global_hyperparameters)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
@@ -226,9 +224,9 @@ class FlashInferMetadata:
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
self.runner = runner
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.device = device
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode
@@ -237,75 +235,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = runner.vllm_config
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back, using the least
# amount of swaps possible. (NOTE: for now we loosely use "decode" to mean
# requests where attention is likely memory-bound and "prefill" to mean
# requests where attention is likely compute-bound, TODO(lucas): figure out
# a better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if it's not,
# we should update this to something like < 8 in the future but
# currently the decode run only supports num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new requests are generally appended to the
# persistent batch so they should already be at the back).
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back",
# i.e. past where the last decode should be in the reordered batch, with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop.
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i.e. past where the last
# decode should be, we can swap it with the prefill closest to the front
# of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
return reorder_batch_to_split_decodes_and_prefills(input_batch,
scheduler_output,
decode_threshold=1)
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.runner.device)
device=self.device)
return self._workspace_buffer
def _get_prefill_wrapper(self):
@@ -316,10 +261,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def _get_decode_wrapper(self):
if self._decode_wrapper is None:
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
num_qo_heads = (
self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config))
num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
self.vllm_config.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
@@ -334,7 +280,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
2, self._get_workspace_buffer(), get_kv_cache_layout())
return self._cascade_wrapper
def _plan(self, attn_metadata: FlashInferMetadata):
def _plan(self, num_prefills: int, num_decodes: int,
attn_metadata: FlashInferMetadata):
if self.global_hyperparameters is None:
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config, FlashInferImpl))
@@ -369,16 +316,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
if self._num_prefills > 0:
if num_prefills > 0:
# Decodes are first so prefills start after the last decode
prefill_start = self._num_decodes
prefill_start = num_decodes
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
assert attn_metadata.qo_indptr[prefill_start:].shape[
0] == self._num_prefills + 1
0] == num_prefills + 1
assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
0] == self._num_prefills + 1
0] == num_prefills + 1
assert attn_metadata.paged_kv_last_page_len[
prefill_start:].shape[0] == self._num_prefills
prefill_start:].shape[0] == num_prefills
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
@@ -402,17 +349,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
kv_data_type=attn_metadata.kv_data_type,
)
if self._num_decodes > 0:
if num_decodes > 0:
attn_metadata.decode_wrapper = self._get_decode_wrapper()
if not FlashInferBackend.use_trtllm_decode_attention(
self._num_decodes, attn_metadata.max_seq_len,
num_decodes, attn_metadata.max_seq_len,
attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads, attn_metadata.head_dim):
attn_metadata.decode_wrapper.plan(
attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
attn_metadata.paged_kv_indptr[:num_decodes + 1],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len[:self.
_num_decodes],
attn_metadata.paged_kv_last_page_len[:num_decodes],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
@@ -427,22 +373,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
kv_data_type=attn_metadata.kv_data_type,
)
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
num_reqs = common_attn_metadata.num_reqs
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashInferMetadata:
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
assert self._num_decodes + self._num_prefills == num_reqs
assert (self._num_decode_tokens +
self._num_prefill_tokens == num_actual_tokens)
page_size = self.kv_cache_spec.block_size
device = self.runner.device
device = self.device
qo_indptr = common_attn_metadata.query_start_loc
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
max_seq_len = common_attn_metadata.seq_lens_cpu.max()
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()
block_table_tensor = common_attn_metadata.block_table_tensor
block_table_bounds = (seq_lens + page_size - 1) // page_size
@@ -487,7 +431,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
cache_dtype = self.runner.cache_config.cache_dtype
cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
cache_dtype)
@@ -499,17 +443,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
num_qo_heads=self.runner.num_query_heads,
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config),
num_kv_heads=self.kv_cache_spec.num_kv_heads,
head_dim=self.kv_cache_spec.head_size,
page_size=page_size,
kv_data_type=kv_cache_dtype,
q_data_type=self.runner.dtype,
slot_mapping=slot_mapping,
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
num_prefills=self._num_prefills,
num_prefill_tokens=self._num_prefill_tokens,
q_data_type=self.vllm_config.model_config.dtype,
slot_mapping=common_attn_metadata.slot_mapping,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
use_cascade=use_cascade,
shared_qo_indptr=shared_qo_indptr,
shared_kv_page_indptr=shared_kv_page_indptr,
@@ -521,12 +466,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
workspace_buffer=self._workspace_buffer,
)
self._plan(attn_metadata)
self._plan(num_prefills, num_decodes, attn_metadata)
return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
# kv cache dtype to something different from query dtype.
return False

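The deleted reorder_batch body above is the logic that reorder_batch_to_split_decodes_and_prefills now provides centrally. A sketch of that algorithm, reconstructed from the removed code: input_batch is assumed to expose req_ids and swap_states(i, j), num_scheduled_tokens maps request id to tokens scheduled this step, and the real helper in vllm/v1/attention/backends/utils.py may differ in details.

def reorder_decodes_then_prefills(input_batch,
                                  num_scheduled_tokens: dict,
                                  decode_threshold: int = 1) -> bool:
    decodes, prefills = [], []
    for i, req_id in enumerate(input_batch.req_ids):
        if num_scheduled_tokens[req_id] <= decode_threshold:
            decodes.append(i)   # likely memory-bound
        else:
            prefills.append(i)  # likely compute-bound

    num_decodes, num_prefills = len(decodes), len(prefills)
    modified_batch = False
    # Walk decodes from the back and prefills from the front; only decodes
    # sitting past index num_decodes - 1 actually need to move.
    for i in range(1, min(num_decodes, num_prefills) + 1):
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            break  # this decode and all earlier ones are already in place
        input_batch.swap_states(prefills[i - 1], decode_idx)
        modified_batch = True
    return modified_batch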

@@ -3,7 +3,7 @@
"""Attention layer with FlashAttention."""
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional
import torch
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
@@ -14,18 +14,15 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
create_block_mask_compiled = torch.compile(create_block_mask,
fullgraph=True,
mode="reduce-overhead")
@@ -261,36 +258,34 @@ class FlexAttentionMetadata:
class FlexAttentionMetadataBuilder(
AttentionMetadataBuilder[FlexAttentionMetadata]):
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
model_config = runner.model_config
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.runner = runner
self.num_heads_q = model_config.get_num_attention_heads(
runner.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
self.num_heads_q = self.model_config.get_num_attention_heads(
vllm_config.parallel_config)
self.num_heads_kv = self.model_config.get_num_kv_heads(
vllm_config.parallel_config)
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
self.device = device
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlexAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
use_cascade = common_prefix_len > 0
cu_prefix_query_lens = None
@@ -300,17 +295,15 @@ class FlexAttentionMetadataBuilder(
raise NotImplementedError("Not yet my friend")
block_size = self.kv_cache_spec.block_size
max_possible_seq_len = self.runner.model_config.max_model_len
total_cache_tokens = (self.runner.cache_config.num_gpu_blocks *
block_size)
max_possible_seq_len = self.model_config.max_model_len
total_cache_tokens = self.cache_config.num_gpu_blocks * block_size
inverse_block_table = physical_to_logical_mapping(
block_table_tensor, self.runner.cache_config.num_gpu_blocks)
block_table_tensor, self.cache_config.num_gpu_blocks)
# Get the original offset tensor
offset_tensor = torch.tensor(
self.runner.input_batch.num_computed_tokens_cpu[:num_reqs]).to(
self.runner.device, non_blocking=True)
offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
self.device, non_blocking=True)
out = FlexAttentionMetadata(
num_actual_tokens=num_actual_tokens,


@@ -7,15 +7,15 @@ from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import MambaSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
@@ -87,80 +87,24 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
block_table: BlockTable):
self.runner = runner
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size(
)
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
assert self.chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models")
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# NOTE (Chen): Copied from MLACommonMetadataBuilder and
# FlashInferMetadataBuilder. Should be refactored later to avoid code
# duplication of these 3 functions.
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back, using the least
# amount of swaps possible. (NOTE: for now we loosely use "decode" to mean
# requests where attention is likely memory-bound and "prefill" to mean
# requests where attention is likely compute-bound, TODO(lucas): figure out
# a better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
return reorder_batch_to_split_decodes_and_prefills(input_batch,
scheduler_output,
decode_threshold=1)
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if it's not,
# we should update this to something like < 8 in the future but
# currently the decode run only supports num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new requests are generally appended to the
# persistent batch so they should already be at the back).
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back",
# i.e. past where the last decode should be in the reordered batch, with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop.
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i.e. past where the last
# decode should be, we can swap it with the prefill closest to the front
# of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> Mamba2AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
@@ -172,29 +116,31 @@ class Mamba2AttentionMetadataBuilder(
has_initial_states = None
prep_initial_states = False
state_indices_tensor = self.block_table.block_table[:num_reqs, 0]
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if self._num_prefills > 0:
if num_prefills > 0:
#[batch,]
has_initial_states_cpu = (
self.runner.input_batch.
num_computed_tokens_cpu_tensor[num_reqs -
self._num_prefills:num_reqs]
> 0)
common_attn_metadata.
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
prep_initial_states = torch.any(has_initial_states_cpu).item()
has_initial_states = has_initial_states_cpu.to(
query_start_loc.device)
query_start_loc_p = common_attn_metadata.query_start_loc[
-self._num_prefills - 1:] - self._num_decode_tokens
-num_prefills - 1:] - num_decode_tokens
seq_idx = torch.repeat_interleave(
torch.arange(self._num_prefills,
dtype=torch.int32,
device=query_start_loc_p.device),
query_start_loc_p.diff(),
output_size=self._num_prefill_tokens)
seq_idx = torch.repeat_interleave(torch.arange(
num_prefills,
dtype=torch.int32,
device=query_start_loc_p.device),
query_start_loc_p.diff(),
output_size=num_prefill_tokens)
seq_idx.unsqueeze_(0)
# We compute metadata for chunked prefill once at the top level
@@ -204,13 +150,13 @@ class Mamba2AttentionMetadataBuilder(
chunk_indices, chunk_offsets = (
_query_start_loc_to_chunk_indices_offsets(
query_start_loc_p, self.chunk_size,
self._num_prefill_tokens))
num_prefill_tokens))
attn_metadata = Mamba2AttentionMetadata(
num_prefills=self._num_prefills,
num_prefill_tokens=self._num_prefill_tokens,
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
has_initial_states=has_initial_states,

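split_decodes_and_prefills(common_attn_metadata, decode_threshold=1) above returns the four counts the old builder cached in reorder_batch. The sketch below illustrates how those counts fall out of the CPU-side query start locations once decodes are ordered first; it shows the semantics only, not the helper's actual implementation.

import torch


def split_counts_sketch(query_start_loc_cpu: torch.Tensor,
                        decode_threshold: int = 1):
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    num_decodes = int((query_lens <= decode_threshold).sum())
    num_prefills = query_lens.numel() - num_decodes
    num_decode_tokens = int(query_lens[:num_decodes].sum())
    num_prefill_tokens = int(query_lens[num_decodes:].sum())
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Two decodes (1 token each) followed by two prefills (5 and 7 tokens):
# split_counts_sketch(torch.tensor([0, 1, 2, 7, 14])) -> (2, 2, 2, 12)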

@@ -202,18 +202,18 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
get_per_layer_parameters,
infer_global_hyperparameters)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
get_per_layer_parameters, infer_global_hyperparameters,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -235,7 +235,6 @@ except ImportError:
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
@@ -406,22 +405,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"""
def __init__(self,
runner: "GPUModelRunner",
kv_cache_spec: AttentionSpec,
block_table: BlockTable,
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[type[M]] = None):
self.metadata_cls = metadata_cls \
if metadata_cls is not None else MLACommonMetadata
self.runner = runner
scheduler_config = runner.scheduler_config
model_config = runner.model_config
cache_config = runner.cache_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
self.num_heads = model_config.get_num_attention_heads(
runner.parallel_config)
self.mla_dims = get_mla_dims(model_config)
self.aot_schedule = current_platform.is_cuda()
self.kv_cache_spec = kv_cache_spec
self.device = device
scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
parallel_config = vllm_config.parallel_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
self.num_heads = self.model_config.get_num_attention_heads(
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()
# Don't try to access the runner on AMD
if self.aot_schedule:
@@ -432,7 +432,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# Make sure there is enough for 8 full-length requests or at least
# 4 pages of cache per request
max(
8 * model_config.max_model_len, 4 *
8 * self.model_config.max_model_len, 4 *
scheduler_config.max_num_seqs * cache_config.block_size),
# For long-context models try not to over-allocate kv-cache space,
# limiting it to 64k tokens,
@@ -447,13 +447,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
scheduler_config.max_num_seqs * cache_config.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
model_config.get_head_size()),
dtype=model_config.dtype,
device=runner.device,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
self.block_table = block_table
self._use_cudnn_prefill = use_cudnn_prefill()
self._use_fi_prefill = use_flashinfer_prefill()
self.prefill_metadata_cls = (
@@ -465,7 +463,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=runner.device)
device=device)
self._fi_prefill_main: Optional[
BatchPrefillWithRaggedKVCacheWrapper] = None
@@ -473,13 +471,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
BatchPrefillWithRaggedKVCacheWrapper] = []
self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(runner.vllm_config, MLACommonImpl))
get_per_layer_parameters(vllm_config, MLACommonImpl))
if self._use_cudnn_prefill:
self.cudnn_workspace = torch.empty(
CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
dtype=torch.int8,
device=runner.device,
device=device,
)
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
@@ -505,7 +503,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
assert num_chunks <= len(self._fi_prefill_chunks)
# In MLA, the non-latent num_qo_heads == num_kv_heads
num_qo_heads = self.runner.num_query_heads
num_qo_heads = self.num_heads
num_kv_heads = num_qo_heads
# Sanity: Verify that num_kv_heads == 1 since it is latent space
@@ -531,7 +529,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.runner.dtype,
q_data_type=self.model_config.dtype,
kv_data_type=self.kv_cache_spec.dtype,
)
@@ -552,7 +550,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.
logits_soft_cap,
q_data_type=self.runner.dtype,
q_data_type=self.model_config.dtype,
kv_data_type=self.kv_cache_spec.dtype,
)
@@ -561,63 +559,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back, using the least
# amount of swaps possible. (NOTE: for now we loosely use "decode" to mean
# requests where attention is likely memory-bound and "prefill" to mean
# requests where attention is likely compute-bound, TODO(lucas): figure out
# a better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if it's not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new requests are generally appended to the
# persistent batch so they should already be at the back).
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back",
# i.e. past where the last decode should be in the reordered batch, with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop.
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i.e. past where the last
# decode should be, we can swap it with the prefill closest to the front
# of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
return reorder_batch_to_split_decodes_and_prefills(input_batch,
scheduler_output,
decode_threshold=1)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
@@ -639,49 +583,50 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
m.max_query_len = 1 # decode-only
# Update state usually set in reorder_batch.
self._num_decodes = m.num_reqs
self._num_decode_tokens = m.num_actual_tokens
self._num_prefills = 0
self._num_prefill_tokens = 0
return self.build(0, m)
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata) -> M:
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> M:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
assert self._num_decodes + self._num_prefills == num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.runner.device
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
device = self.device
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
prefill_metadata = None
if self._num_prefills > 0:
reqs_start = self._num_decodes # prefill_start
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
context_lens_cpu = self.runner.input_batch.\
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
query_seq_lens_cpu)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
prefill_metadata = None
if num_prefills > 0:
reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
chunked_context_metadata = None
if self.chunked_prefill_enabled and self._num_prefills > 0 \
if self.chunked_prefill_enabled and num_prefills > 0 \
and max_context_len_cpu > 0:
# NOTE: it is recommended you read the `Chunked Prefill` section
# in the comment at the top of the file before trying to
@@ -712,14 +657,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# of `to_list`.
chunk_starts = \
torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, self._num_prefills) \
.unsqueeze(1).expand(-1, num_prefills) \
* max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
self._num_prefills + 1,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
@@ -762,28 +707,28 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata.cudnn_workspace = self.cudnn_workspace
decode_metadata = None
if self._num_decodes > 0:
if num_decodes > 0:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
seq_lens=seq_lens[:self._num_decodes],
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens=seq_lens[:num_decodes],
)
attn_metadata = self.metadata_cls(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
num_actual_tokens=num_actual_tokens,
num_actual_tokens=num_tokens,
query_start_loc=query_start_loc,
slot_mapping=slot_mapping,
head_dim=self.runner.model_config.get_head_size(),
head_dim=self.model_config.get_head_size(),
# MLACommonMetadata Chunk prefill specific
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
num_prefills=self._num_prefills,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
prefill=prefill_metadata,
decode=decode_metadata,
)
if self._use_fi_prefill and self._num_prefills > 0:
if self._use_fi_prefill and num_prefills > 0:
assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
self._build_fi_prefill_wrappers(attn_metadata.prefill)

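The chunked-context arithmetic in the hunk above splits each prefill's already-computed context into num_chunks slabs of at most max_context_chunk tokens. A small worked example of the chunk_starts / chunk_ends / chunk_seq_lens tensors, using made-up sizes (2 prefills with context lengths 5 and 12, max_context_chunk = 8, so num_chunks = 2):

import torch

context_lens_cpu = torch.tensor([5, 12], dtype=torch.int32)
max_context_chunk = 8
num_chunks = 2
num_prefills = 2

chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk)
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                       chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

# chunk_starts   -> [[0, 0], [8, 8]]
# chunk_ends     -> [[5, 8], [5, 12]]
# chunk_seq_lens -> [[5, 8], [0, 4]]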

@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionType,
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
@@ -18,7 +19,6 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
logger = init_logger(__name__)
@@ -56,12 +56,13 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
def __init__(self, runner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata)
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config)
self.compilation_config = vllm_config.compilation_config
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
@@ -75,7 +76,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1, # MQA for the decode path
)
if self.runner.full_cuda_graph:
if self.compilation_config.full_cuda_graph:
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_tile_scheduler_metadata is None:
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata

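The full_cuda_graph branch above now checks the compilation config rather than the runner and allocates a static buffer on the first (capture-time) build. The rest of that branch is outside this diff; the following is only a sketch of the usual capture/replay buffer-reuse pattern it suggests, with a hypothetical method name and an assumed replay path.

def _reuse_static_tile_metadata(self, tile_scheduler_metadata):
    if self.cg_buf_tile_scheduler_metadata is None:
        # Capture time: keep this allocation as the graph's static buffer.
        self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
        return tile_scheduler_metadata
    # Replay time (assumed): copy into the captured storage so the graph
    # keeps reading from the same tensor.
    self.cg_buf_tile_scheduler_metadata.copy_(tile_scheduler_metadata)
    return self.cg_buf_tile_scheduler_metadata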

@@ -8,6 +8,8 @@ import torch
import vllm.envs as envs
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
from vllm.config import VllmConfig
from vllm.utils import cdiv
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -16,7 +18,6 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
# yapf: enable
@@ -65,24 +66,26 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only
def __init__(self, runner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata)
assert self.kv_cache_spec.block_size == 1, "AITER MLA " \
"only supports block size 1."
self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
# Preparing persistent buffers
if self.runner.full_cuda_graph:
device = self.runner.device
max_num_reqs = self.runner.max_num_reqs
if vllm_config.compilation_config.full_cuda_graph:
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=device)
self.paged_kv_indices = torch.zeros(
block_table.get_device_tensor().numel(
), # max num pages possible
dtype=torch.int32,
device=device)
self.paged_kv_indices = torch.zeros(max_num_pages,
dtype=torch.int32,
device=device)
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=device)
@@ -96,7 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size
device = self.runner.device
device = self.device
num_reqs = seq_lens.size(0)
mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
@@ -113,8 +117,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])
if self.runner.full_cuda_graph:
num_reqs = self._num_decodes
if self.compilation_config.full_cuda_graph:
num_actual_pages = paged_kv_indices.size(0)
@@ -137,7 +140,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
else:
qo_indptr = torch.arange(0,
self._num_decodes + 1,
num_reqs + 1,
step=1,
dtype=torch.int32,
device=device)

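The persistent buffers above are now sized from the config alone rather than from the runner's block table: max_num_pages = max_num_seqs * cdiv(max_model_len, block_size). A quick numeric check of that bound with made-up config values:

from math import ceil

# Made-up values: max_model_len = 8192, block_size = 16, max_num_seqs = 4.
max_num_pages_per_req = ceil(8192 / 16)    # cdiv(max_model_len, block_size) -> 512
max_num_pages = 4 * max_num_pages_per_req  # max_num_seqs * pages per req -> 2048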

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional
import torch
@@ -10,18 +10,13 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if current_platform.is_rocm():
import aiter
@@ -172,54 +167,49 @@ logger = init_logger(__name__)
class AiterFlashAttentionMetadataBuilder:
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
model_config = runner.model_config
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.device = device
self.runner = runner
self.num_heads_q = model_config.get_num_attention_heads(
runner.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)
self.num_heads_kv = self.model_config.get_num_kv_heads(
self.parallel_config)
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
def reorder_batch(self, input_batch, scheduler_output) -> bool:
return False
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
dtype=torch.int32,
device="cuda")
device=self.device)
torch.cumsum(seq_lens,
dim=0,
dtype=cu_seq_lens.dtype,
@@ -231,21 +221,21 @@ class AiterFlashAttentionMetadataBuilder:
# for local attention
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
if self.model_config.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.runner.attention_chunk_size,
self.runner.query_start_loc_np[:num_reqs + 1],
self.runner.seq_lens_np[:num_reqs],
self.model_config.attention_chunk_size,
query_start_loc_cpu.numpy(),
seq_lens_cpu.numpy(),
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.runner.device, non_blocking=True)
self.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True)
local_max_query_len = int(seqlens_q_local_np.max())
local_max_seq_len = int(virt_k_seqlens_np.max())
self.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max().item()
local_max_seq_len = virt_k_seqlens_np.max().item()
local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
@@ -256,12 +246,11 @@ class AiterFlashAttentionMetadataBuilder:
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
dtype=torch.int32,
device=self.runner.device)
device=self.device)
local_cu_seq_lens[1:] = torch.cumsum(
torch.from_numpy(virt_k_seqlens_np).to(
device=self.runner.device,
dtype=torch.int32,
non_blocking=True),
torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
dtype=torch.int32,
non_blocking=True),
dim=0)

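For context on the hunk above, here is a minimal sketch (illustrative only; the wrapper function is hypothetical) of how a builder now derives the chunked "local" attention views purely from CommonAttentionMetadata CPU tensors instead of reaching into the model runner:

    # Illustrative only: carving a batch into virtual local-attention batches
    # using only CommonAttentionMetadata, as the refactored builders do.
    import torch
    from vllm.v1.attention.backends.utils import (
        CommonAttentionMetadata, make_local_attention_virtual_batches)

    def build_local_views(m: CommonAttentionMetadata, attention_chunk_size: int,
                          block_table_tensor: torch.Tensor, block_size: int):
        # All inputs come from the shared metadata; no runner access needed.
        seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local = \
            make_local_attention_virtual_batches(
                attention_chunk_size,
                m.query_start_loc_cpu.numpy(),
                m.seq_lens_cpu.numpy(),
                block_table_tensor,
                block_size,
            )
        return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
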
View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from typing import Any, ClassVar, Optional
import torch
@@ -14,6 +14,7 @@ from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -21,10 +22,6 @@ from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
make_local_attention_virtual_batches)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
@@ -75,12 +72,21 @@ class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
self.runner = runner
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.device = device
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
model_config = vllm_config.model_config
self.num_heads_q = model_config.get_num_attention_heads(
vllm_config.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
vllm_config.parallel_config)
self.headdim = model_config.get_head_size()
self.attention_chunk_size = getattr(vllm_config.scheduler_config,
'attention_chunk_size', None)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
@@ -92,46 +98,36 @@ class TritonAttentionMetadataBuilder(
attn_metadata.seq_lens.fill_(1)
return attn_metadata
def build(
self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> TritonAttentionMetadata:
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
# for local attention
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
if self.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.runner.attention_chunk_size,
self.runner.query_start_loc_np[:num_reqs + 1],
self.runner.seq_lens_np[:num_reqs],
self.attention_chunk_size,
common_attn_metadata.query_start_loc_cpu.numpy(),
common_attn_metadata.seq_lens_cpu.numpy(),
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.runner.device, non_blocking=True)
self.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max()
local_max_seq_len = virt_k_seqlens_np.max()
self.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max().item()
local_max_seq_len = virt_k_seqlens_np.max().item()
local_attn_metadata = TritonAttentionMetadata \
.LocalAttentionMetadata(
@@ -148,14 +144,13 @@ class TritonAttentionMetadataBuilder(
if use_cascade:
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device=self.runner.device)
device=self.device)
prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32,
device=self.runner.device)
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
device=self.device)
suffix_kv_lens = (common_attn_metadata.seq_lens_cpu -
common_prefix_len)
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
self.runner.device)
suffix_kv_lens = suffix_kv_lens.to(self.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None

View File

@@ -22,6 +22,7 @@ import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
_KV_CACHE_LAYOUT_OVERRIDE = None
@@ -32,14 +33,22 @@ class CommonAttentionMetadata:
"""
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.
For many of the tensors we keep both GPU and CPU versions.
"""
query_start_loc: torch.Tensor
query_start_loc_cpu: torch.Tensor
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
num_computed_tokens_cpu: torch.Tensor
"""(batch_size,), the number of computed tokens for each request"""
num_reqs: int
"""Number of requests"""
num_actual_tokens: int
@@ -47,6 +56,14 @@ class CommonAttentionMetadata:
max_query_len: int
"""Longest query in batch"""
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
def __post_init__(self):
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
self.slot_mapping[self.num_actual_tokens:].fill_(-1)
M = TypeVar("M")
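A small sketch of the padded slot_mapping contract established by __post_init__ above (the sizes are made up for illustration): entries past num_actual_tokens are set to -1 so reshape_and_cache can run safely when a full CUDA graph was captured for a larger token count.

    # Illustrative only: what the -1 padding in __post_init__ looks like.
    import torch

    num_actual_tokens, padded_tokens = 6, 8
    slot_mapping = torch.arange(padded_tokens, dtype=torch.int64)
    slot_mapping[num_actual_tokens:].fill_(-1)   # mirrors __post_init__
    assert slot_mapping.tolist() == [0, 1, 2, 3, 4, 5, -1, -1]
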
@@ -56,11 +73,25 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
full_cudagraph_supported: ClassVar[bool] = False
@abstractmethod
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata) -> M:
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.kv_cache_spec = kv_cache_spec
@abstractmethod
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> M:
"""
Central method that builds attention metadata.
Some builders (MLA) require reorder_batch to be called prior to build.
Args:
common_prefix_len: The length of the common prefix of the batch.
common_attn_metadata: The common attention metadata.
fast_build: The metadata will prioritize speed of building over
speed at execution. Can be used for spec-decode where the
result of a build call may only be used for a few layers/iters.
"""
raise NotImplementedError
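To make the new interface concrete, a skeletal builder might look like the following. This is a sketch only: MyMetadata and MyMetadataBuilder are invented names, and the fields chosen are just examples of what a backend could copy out of CommonAttentionMetadata.

    # Illustrative only: a minimal builder against the refactored interface.
    from dataclasses import dataclass

    import torch

    from vllm.config import VllmConfig
    from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                                  CommonAttentionMetadata)
    from vllm.v1.kv_cache_interface import AttentionSpec


    @dataclass
    class MyMetadata:  # hypothetical per-layer metadata
        num_actual_tokens: int
        max_seq_len: int
        slot_mapping: torch.Tensor
        block_table: torch.Tensor


    class MyMetadataBuilder(AttentionMetadataBuilder[MyMetadata]):

        def __init__(self, kv_cache_spec: AttentionSpec,
                     vllm_config: VllmConfig, device: torch.device):
            # Builders now receive the config and device directly instead of
            # a model-runner reference.
            self.kv_cache_spec = kv_cache_spec
            self.vllm_config = vllm_config
            self.device = device

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> MyMetadata:
            m = common_attn_metadata
            return MyMetadata(
                num_actual_tokens=m.num_actual_tokens,
                max_seq_len=int(m.seq_lens_cpu.max()),
                slot_mapping=m.slot_mapping,
                block_table=m.block_table_tensor,
            )
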
@@ -351,3 +382,108 @@ def make_local_attention_virtual_batches(
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
block_table_local
def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.
Args:
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
decode_threshold: The maximum query length to be considered a decode.
Returns:
num_decodes: The number of decode requests.
num_prefills: The number of prefill requests.
num_decode_tokens: The number of tokens in the decode requests.
num_prefill_tokens: The number of tokens in the prefill requests.
"""
max_query_len = common_attn_metadata.max_query_len
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
if max_query_len <= decode_threshold:
return num_reqs, 0, num_tokens, 0
query_lens = query_start_loc[1:] - query_start_loc[:-1]
is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[first_prefill:] > decode_threshold)
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes
num_decode_tokens = query_start_loc[first_prefill].item()
num_prefill_tokens = num_tokens - num_decode_tokens
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
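As a usage sketch (the surrounding function is hypothetical), a builder that handles decodes and prefills separately would typically call this helper at the top of build(), after the batch has been reordered:

    # Illustrative only: splitting a reordered batch inside a builder's build().
    from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                                  split_decodes_and_prefills)

    def split(m: CommonAttentionMetadata):
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(m, decode_threshold=1)
        # Decode requests occupy the first num_decodes entries of the batch;
        # prefill tokens start at offset num_decode_tokens in the token dim.
        return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens
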
def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",
decode_threshold: int = 1,
) -> bool:
"""
Reorders the batch to split into prefill and decode requests; places all
requests with <= decode_threshold tokens at the front of the batch.
Returns:
True if the batch was modified, False otherwise.
"""
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back using the least
# amount of swaps possible. (NOTE for now we loosely use "decode" to mean
# requests where attention is likely memory-bound and "prefill" to mean
# requests where attention is likely compute-bound, TODO(lucas): figure out
# a better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if it's not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens <= decode_threshold:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new requests are generally appended to the
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back",
# i.e. past where the last decode should be in the reordered batch,
# with prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is in the "back" of the batch (index >= num_decodes),
# we can swap it with the prefill closest to the front of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
return modified_batch
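
Finally, a sketch of how a backend would delegate its reorder_batch() to this shared helper (the builder class name and threshold are hypothetical):

    # Illustrative only: delegating reorder_batch() to the shared helper.
    from vllm.v1.attention.backends.utils import (
        reorder_batch_to_split_decodes_and_prefills)

    class MyMLAStyleBuilder:

        def reorder_batch(self, input_batch, scheduler_output) -> bool:
            return reorder_batch_to_split_decodes_and_prefills(
                input_batch, scheduler_output, decode_threshold=1)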