[V1] Implement Cascade Attention (#11635)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon
2025-01-01 21:56:46 +09:00
committed by GitHub
parent 6d70198b17
commit 73001445fb
10 changed files with 696 additions and 19 deletions


@@ -72,6 +72,8 @@ class GPUModelRunner:
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
self.num_query_heads = model_config.get_num_attention_heads(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
@@ -118,6 +120,10 @@ class GPUModelRunner:
self.cudagraph_batch_sizes = list(
reversed(self.vllm_config.compilation_config.capture_sizes))
# Cache the device properties.
self.device_properties = torch.cuda.get_device_properties(self.device)
self.num_sms = self.device_properties.multi_processor_count
# Persistent buffers for CUDA graphs.
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
@@ -131,7 +137,8 @@ class GPUModelRunner:
device=self.device)
# OPTIMIZATION: Cache the tensors rather than creating them every step.
self.arange_np = np.arange(max(self.max_num_reqs + 1,
self.max_model_len),
dtype=np.int32)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
@@ -355,6 +362,88 @@ class GPUModelRunner:
self.device, non_blocking=True)
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
self.device, non_blocking=True).long()
# Prepare for cascade attention if needed.
common_prefix_len = (scheduler_output.num_common_prefix_blocks *
self.block_size)
if common_prefix_len == 0:
# Common case.
use_cascade = False
else:
# NOTE(woosuk): Cascade attention uses two attention kernels: one
# for the common prefix and the other for the rest. For the first
# kernel, we concatenate all the query tokens (possibly from
# different requests) and treat them as if they are from the same
# request. Then, we use bi-directional attention to process the
# common prefix in the KV cache. Importantly, this means that the
# first kernel does not do any masking.
# Consider the following example:
# Request 1's input query: [D, E, X]
# Request 1's kv cache: [A, B, C, D, E, X]
# Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
# Request 2's input query: [E, Y]
# Request 2's kv cache: [A, B, C, D, E, Y]
# Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
# If we use [A, B, C, D, E] as the common prefix, then the
# first kernel will compute the bi-directional attention between
# input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
# However, this is wrong because D in Request 1 should not attend to
# E in the common prefix (i.e., we need masking).
# To avoid this, [A, B, C, D] should be the common prefix.
# That is, the common prefix should be capped at the minimum
# num_computed_tokens among the requests, plus one to include the
# first token of the query.
# In practice, we use [A, B, C] as the common prefix instead of
# [A, B, C, D] (i.e., the common prefix is capped at the minimum
# num_computed_tokens, without the plus one).
# This is because of an implementation detail: We want to always
# use two kernels for cascade attention. Let's imagine:
# Request 3's input query: [D]
# Request 3's kv cache: [A, B, C, D]
# Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
# If we use [A, B, C, D] as the common prefix for Requests 1-3,
# then Request 3 will be processed only by the first kernel,
# and the second kernel will get an empty input. While this is not
# a fundamental problem, our current implementation does not support
# this case.
common_prefix_len = min(
common_prefix_len,
self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // self.block_size *
self.block_size)
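
As a small worked example of the two capping steps above, with hypothetical numbers (block size, block count, and computed-token counts are made up for illustration):

import numpy as np

# The scheduler reports 8 shared blocks of 16 tokens, but one request has
# only 120 computed tokens in its KV cache.
block_size = 16
num_common_prefix_blocks = 8
num_computed_tokens = np.array([120, 200, 176], dtype=np.int32)

common_prefix_len = num_common_prefix_blocks * block_size                # 128
common_prefix_len = min(common_prefix_len, num_computed_tokens.min())    # 120
common_prefix_len = common_prefix_len // block_size * block_size         # 112
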
use_cascade = FlashAttentionBackend.use_cascade_attention(
common_prefix_len=common_prefix_len,
query_lens=num_scheduled_tokens,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
use_alibi=False, # FIXME
use_sliding_window=self.sliding_window is not None,
num_sms=self.num_sms,
)
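
FlashAttentionBackend.use_cascade_attention decides whether the split is actually worthwhile. A rough sketch of the kind of checks such a heuristic might make, with made-up names and thresholds (not the actual backend logic):

def maybe_use_cascade(common_prefix_len, query_lens, num_query_heads,
                      num_kv_heads, use_alibi, use_sliding_window,
                      num_sms):
    # Hypothetical heuristic, for illustration only: cascade attention pays
    # off when there is a long shared prefix, the attention variant is
    # supported, and the two-kernel split still keeps the SMs busy.
    if common_prefix_len < 256:            # made-up threshold
        return False
    if use_alibi or use_sliding_window:    # assume unsupported by the cascade path
        return False
    num_reqs = len(query_lens)
    if num_reqs < 2:                       # nothing is actually shared
        return False
    # Rough occupancy check: enough (request, KV-head) work for the suffix kernel.
    return num_reqs * num_kv_heads >= num_sms // 8   # made-up ratio
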
if use_cascade:
# TODO: Optimize.
cu_prefix_query_lens = torch.tensor(
[0, total_num_scheduled_tokens],
dtype=torch.int32,
device=self.device)
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
dtype=torch.int32,
device=self.device)
cu_suffix_kv_lens = (
self.seq_start_loc_np[:num_reqs + 1] -
self.arange_np[:num_reqs + 1] * common_prefix_len)
cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to(
self.device)
else:
cu_prefix_query_lens = None
cu_prefix_kv_lens = None
cu_suffix_kv_lens = None
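
For the two example requests in the comment above, these tensors would look roughly as follows (a hypothetical NumPy sketch mirroring the arithmetic, not the runner's code path):

import numpy as np

# Hypothetical values (block alignment ignored for brevity): query lengths
# [3, 2], KV lengths [6, 6], shared prefix [A, B, C] of length 3.
num_reqs = 2
common_prefix_len = 3
seq_start_loc = np.array([0, 6, 12], dtype=np.int32)   # cumulative KV lengths

cu_prefix_query_lens = np.array([0, 3 + 2], dtype=np.int32)   # all queries as one sequence
cu_prefix_kv_lens = np.array([0, common_prefix_len], dtype=np.int32)
cu_suffix_kv_lens = (seq_start_loc[:num_reqs + 1] -
                     np.arange(num_reqs + 1) * common_prefix_len)
# -> [0, 3, 6]: each request keeps a 3-token suffix ([D, E, X] and [D, E, Y]).
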
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
@@ -363,6 +452,11 @@ class GPUModelRunner:
seq_start_loc=seq_start_loc,
block_table=self.input_batch.block_table[:num_reqs],
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
cu_prefix_kv_lens=cu_prefix_kv_lens,
cu_suffix_kv_lens=cu_suffix_kv_lens,
)
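
With use_cascade set, the backend runs two kernels and then merges their partial outputs. A minimal single-head NumPy sketch of that merge, assuming each pass also returns per-query log-sum-exp values (illustrative only; the real kernels are FlashAttention calls, and masking is omitted here so the merged result can be checked against a single full pass):

import numpy as np

def attn_with_lse(q, k, v):
    # q: (n_q, d), k/v: (n_kv, d). Returns the attention output and the
    # per-query log-sum-exp of the scores.
    s = q @ k.T / np.sqrt(q.shape[-1])
    lse = np.log(np.exp(s).sum(axis=-1))
    o = np.exp(s - lse[:, None]) @ v
    return o, lse

rng = np.random.default_rng(0)
d, n_q, n_prefix, n_suffix = 8, 3, 5, 4
q = rng.standard_normal((n_q, d))
k = rng.standard_normal((n_prefix + n_suffix, d))
v = rng.standard_normal((n_prefix + n_suffix, d))

# Kernel 1: all query tokens vs. the shared prefix (no masking).
o1, lse1 = attn_with_lse(q, k[:n_prefix], v[:n_prefix])
# Kernel 2: each request's query tokens vs. its own suffix.
o2, lse2 = attn_with_lse(q, k[n_prefix:], v[n_prefix:])

# Merge: rescale each partial output by its share of the total softmax mass.
lse = np.logaddexp(lse1, lse2)
o = np.exp(lse1 - lse)[:, None] * o1 + np.exp(lse2 - lse)[:, None] * o2

# The merged result matches one attention pass over the full KV.
o_ref, _ = attn_with_lse(q, k, v)
assert np.allclose(o, o_ref)
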
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this