[V1] Implement Cascade Attention (#11635)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -72,6 +72,8 @@ class GPUModelRunner:
|
||||
# Model-related.
|
||||
self.num_attn_layers = model_config.get_num_layers_by_block_type(
|
||||
parallel_config, LayerBlockType.attention)
|
||||
self.num_query_heads = model_config.get_num_attention_heads(
|
||||
parallel_config)
|
||||
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
self.head_size = model_config.get_head_size()
|
||||
self.hidden_size = model_config.get_hidden_size()
|
||||
@@ -118,6 +120,10 @@ class GPUModelRunner:
|
||||
self.cudagraph_batch_sizes = list(
|
||||
reversed(self.vllm_config.compilation_config.capture_sizes))
|
||||
|
||||
# Cache the device properties.
|
||||
self.device_properties = torch.cuda.get_device_properties(self.device)
|
||||
self.num_sms = self.device_properties.multi_processor_count
|
||||
|
||||
# Persistent buffers for CUDA graphs.
|
||||
self.input_ids = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
@@ -131,7 +137,8 @@ class GPUModelRunner:
|
||||
device=self.device)
|
||||
|
||||
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
||||
self.arange_np = np.arange(max(self.max_num_reqs, self.max_model_len),
|
||||
self.arange_np = np.arange(max(self.max_num_reqs + 1,
|
||||
self.max_model_len),
|
||||
dtype=np.int32)
|
||||
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
|
||||
# a faster version of creating a new tensor every time. Thus, we should
|
||||
@@ -355,6 +362,88 @@ class GPUModelRunner:
|
||||
self.device, non_blocking=True)
|
||||
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
||||
self.device, non_blocking=True).long()
|
||||
|
||||
# Prepare for cascade attention if needed.
|
||||
common_prefix_len = (scheduler_output.num_common_prefix_blocks *
|
||||
self.block_size)
|
||||
if common_prefix_len == 0:
|
||||
# Common case.
|
||||
use_cascade = False
|
||||
else:
|
||||
# NOTE(woosuk): Cascade attention uses two attention kernels: one
|
||||
# for the common prefix and the other for the rest. For the first
|
||||
# kernel, we concatenate all the query tokens (possibly from
|
||||
# different requests) and treat them as if they are from the same
|
||||
# request. Then, we use bi-directional attention to process the
|
||||
# common prefix in the KV cache. Importantly, this means that the
|
||||
# first kernel does not do any masking.
|
||||
|
||||
# Consider the following example:
|
||||
# Request 1's input query: [D, E, X]
|
||||
# Request 1's kv cache: [A, B, C, D, E, X]
|
||||
# Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
|
||||
# Request 2's input query: [E, Y]
|
||||
# Request 2's kv cache: [A, B, C, D, E, Y]
|
||||
# Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
|
||||
|
||||
# If we use [A, B, C, D, E] as the common prefix, then the
|
||||
# first kernel will compute the bi-directional attention between
|
||||
# input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
|
||||
# However, this is wrong because D in Request 1 should not attend to
|
||||
# E in the common prefix (i.e., we need masking).
|
||||
# To avoid this, [A, B, C, D] should be the common prefix.
|
||||
# That is, the common prefix should be capped by the minimum
|
||||
# num_computed_tokens among the requests, and plus one to include
|
||||
# the first token of the query.
|
||||
|
||||
# In practice, we use [A, B, C] as the common prefix, instead of
|
||||
# [A, B, C, D] (i.e., the common prefix is capped by the minimum
|
||||
# num_computed_tokens, without plus one).
|
||||
# This is because of an implementation detail: We want to always
|
||||
# use two kernels for cascade attention. Let's imagine:
|
||||
# Request 3's input query: [D]
|
||||
# Request 3's kv cache: [A, B, C, D]
|
||||
# Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
|
||||
# If we use [A, B, C, D] as the common prefix for Request 1-3,
|
||||
# then Request 3 will be processed only by the first kernel,
|
||||
# and the second kernel will get an empty input. While this is not
|
||||
# a fundamental problem, our current implementation does not support
|
||||
# this case.
|
||||
common_prefix_len = min(
|
||||
common_prefix_len,
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
|
||||
# common_prefix_len should be a multiple of the block size.
|
||||
common_prefix_len = (common_prefix_len // self.block_size *
|
||||
self.block_size)
|
||||
use_cascade = FlashAttentionBackend.use_cascade_attention(
|
||||
common_prefix_len=common_prefix_len,
|
||||
query_lens=num_scheduled_tokens,
|
||||
num_query_heads=self.num_query_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
use_alibi=False, # FIXME
|
||||
use_sliding_window=self.sliding_window is not None,
|
||||
num_sms=self.num_sms,
|
||||
)
|
||||
|
||||
if use_cascade:
|
||||
# TODO: Optimize.
|
||||
cu_prefix_query_lens = torch.tensor(
|
||||
[0, total_num_scheduled_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
cu_suffix_kv_lens = (
|
||||
self.seq_start_loc_np[:num_reqs + 1] -
|
||||
self.arange_np[:num_reqs + 1] * common_prefix_len)
|
||||
cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to(
|
||||
self.device)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
cu_prefix_kv_lens = None
|
||||
cu_suffix_kv_lens = None
|
||||
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
@@ -363,6 +452,11 @@ class GPUModelRunner:
|
||||
seq_start_loc=seq_start_loc,
|
||||
block_table=self.input_batch.block_table[:num_reqs],
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
cu_prefix_kv_lens=cu_prefix_kv_lens,
|
||||
cu_suffix_kv_lens=cu_suffix_kv_lens,
|
||||
)
|
||||
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
|
||||
# request in the batch. While we should not sample any token from this
|
||||
|
||||
Reference in New Issue
Block a user