[Attention] Refactor attention metadata builder interface (#20466)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -15,22 +15,20 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
PerLayerParameters,
|
||||
get_kv_cache_layout,
|
||||
get_per_layer_parameters,
|
||||
infer_global_hyperparameters)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
|
||||
get_kv_cache_layout, get_per_layer_parameters,
|
||||
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
|
||||
@@ -226,9 +224,9 @@ class FlashInferMetadata:
|
||||
|
||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
self.runner = runner
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
||||
device: torch.device):
|
||||
self.device = device
|
||||
self._workspace_buffer = None
|
||||
self._prefill_wrapper = None # Wrapper for prefill/append
|
||||
self._decode_wrapper = None # Wrapper for decode
|
||||
@@ -237,75 +235,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# Global hyperparameters shared by all attention layers
|
||||
self.global_hyperparameters: Optional[PerLayerParameters] = None
|
||||
|
||||
self.vllm_config = runner.vllm_config
|
||||
self.vllm_config = vllm_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
def reorder_batch(self, input_batch: InputBatch,
|
||||
scheduler_output: SchedulerOutput) -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are and
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the decode run only supports num_tokens = 1
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=1)
|
||||
|
||||
def _get_workspace_buffer(self):
|
||||
if self._workspace_buffer is None:
|
||||
self._workspace_buffer = torch.empty(
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device=self.runner.device)
|
||||
device=self.device)
|
||||
return self._workspace_buffer
|
||||
|
||||
def _get_prefill_wrapper(self):
|
||||
@@ -316,10 +261,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
def _get_decode_wrapper(self):
|
||||
if self._decode_wrapper is None:
|
||||
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config))
|
||||
num_kv_heads = self.runner.model_config.get_num_kv_heads(
|
||||
self.runner.parallel_config)
|
||||
num_qo_heads = (
|
||||
self.vllm_config.model_config.get_num_attention_heads(
|
||||
self.vllm_config.parallel_config))
|
||||
num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
|
||||
self.vllm_config.parallel_config)
|
||||
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
|
||||
num_qo_heads // num_kv_heads > 4)
|
||||
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
@@ -334,7 +280,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
2, self._get_workspace_buffer(), get_kv_cache_layout())
|
||||
return self._cascade_wrapper
|
||||
|
||||
def _plan(self, attn_metadata: FlashInferMetadata):
|
||||
def _plan(self, num_prefills: int, num_decodes: int,
|
||||
attn_metadata: FlashInferMetadata):
|
||||
if self.global_hyperparameters is None:
|
||||
self.global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(self.vllm_config, FlashInferImpl))
|
||||
@@ -369,16 +316,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# Regular attention (common case).
|
||||
# Decodes are at the front and prefills are at the back,
|
||||
# according to reorder_batch()
|
||||
if self._num_prefills > 0:
|
||||
if num_prefills > 0:
|
||||
# Decodes are first so prefills start after the last decode
|
||||
prefill_start = self._num_decodes
|
||||
prefill_start = num_decodes
|
||||
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
|
||||
assert attn_metadata.qo_indptr[prefill_start:].shape[
|
||||
0] == self._num_prefills + 1
|
||||
0] == num_prefills + 1
|
||||
assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
|
||||
0] == self._num_prefills + 1
|
||||
0] == num_prefills + 1
|
||||
assert attn_metadata.paged_kv_last_page_len[
|
||||
prefill_start:].shape[0] == self._num_prefills
|
||||
prefill_start:].shape[0] == num_prefills
|
||||
# Since prefill_wrapper.run() will be called with
|
||||
# query[num_decode_tokens:] we need to adjust the qo_indptr
|
||||
# to be relative to the start of the prefill queries.
|
||||
@@ -402,17 +349,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
kv_data_type=attn_metadata.kv_data_type,
|
||||
)
|
||||
|
||||
if self._num_decodes > 0:
|
||||
if num_decodes > 0:
|
||||
attn_metadata.decode_wrapper = self._get_decode_wrapper()
|
||||
if not FlashInferBackend.use_trtllm_decode_attention(
|
||||
self._num_decodes, attn_metadata.max_seq_len,
|
||||
num_decodes, attn_metadata.max_seq_len,
|
||||
attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads, attn_metadata.head_dim):
|
||||
attn_metadata.decode_wrapper.plan(
|
||||
attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
|
||||
attn_metadata.paged_kv_indptr[:num_decodes + 1],
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_len[:self.
|
||||
_num_decodes],
|
||||
attn_metadata.paged_kv_last_page_len[:num_decodes],
|
||||
attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim,
|
||||
@@ -427,22 +373,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
kv_data_type=attn_metadata.kv_data_type,
|
||||
)
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> FlashInferMetadata:
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
|
||||
split_decodes_and_prefills(common_attn_metadata)
|
||||
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
assert (self._num_decode_tokens +
|
||||
self._num_prefill_tokens == num_actual_tokens)
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
device = self.runner.device
|
||||
device = self.device
|
||||
qo_indptr = common_attn_metadata.query_start_loc
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
max_seq_len = common_attn_metadata.seq_lens_cpu.max()
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
|
||||
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True).long()
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
|
||||
block_table_bounds = (seq_lens + page_size - 1) // page_size
|
||||
|
||||
@@ -487,7 +431,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
paged_kv_last_page_len = seq_lens % page_size
|
||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
|
||||
page_size, paged_kv_last_page_len)
|
||||
cache_dtype = self.runner.cache_config.cache_dtype
|
||||
cache_dtype = self.cache_config.cache_dtype
|
||||
if cache_dtype.startswith("fp8"):
|
||||
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
cache_dtype)
|
||||
@@ -499,17 +443,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
num_qo_heads=self.runner.num_query_heads,
|
||||
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
|
||||
self.vllm_config.parallel_config),
|
||||
num_kv_heads=self.kv_cache_spec.num_kv_heads,
|
||||
head_dim=self.kv_cache_spec.head_size,
|
||||
page_size=page_size,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
q_data_type=self.runner.dtype,
|
||||
slot_mapping=slot_mapping,
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
num_prefills=self._num_prefills,
|
||||
num_prefill_tokens=self._num_prefill_tokens,
|
||||
q_data_type=self.vllm_config.model_config.dtype,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
use_cascade=use_cascade,
|
||||
shared_qo_indptr=shared_qo_indptr,
|
||||
shared_kv_page_indptr=shared_kv_page_indptr,
|
||||
@@ -521,12 +466,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
)
|
||||
|
||||
self._plan(attn_metadata)
|
||||
self._plan(num_prefills, num_decodes, attn_metadata)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
|
||||
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
|
||||
# TODO: The cascade wrapper currently does not support setting
|
||||
# kv cache dtype to something different from query dtype.
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user