Implement PagedAttention V2 (#1348)
@@ -15,6 +15,8 @@ from vllm.model_executor.layers.rotary_embedding import (
     RotaryEmbedding)
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
+_PARTITION_SIZE = 512
 
 
 class PagedAttention(nn.Module):
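For context, the V2 kernel splits each sequence's attention over fixed-size partitions of the key/value cache, and `_PARTITION_SIZE` is the number of tokens per partition. A small illustrative computation with a hypothetical context length, mirroring the ceiling division that appears later in this diff:

    # Illustrative only: the longest context is ceil-divided into 512-token partitions.
    _PARTITION_SIZE = 512
    max_context_len = 1300  # hypothetical example value
    max_num_partitions = (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    assert max_num_partitions == 3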
@@ -130,6 +132,14 @@ class PagedAttention(nn.Module):
         output.copy_(out.squeeze(0))
         return output
 
+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        """Returns the slopes for the alibi attention bias.
+
+        Returns:
+            slopes: shape = [num_heads]
+        """
+        return None
+
     def single_query_cached_kv_attention(
         self,
         output: torch.Tensor,
@@ -137,6 +147,7 @@ class PagedAttention(nn.Module):
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         input_metadata: InputMetadata,
+        alibi_slopes: Optional[torch.Tensor],
     ) -> None:
         """PagedAttention for the generation tokens.
 
@@ -148,21 +159,65 @@ class PagedAttention(nn.Module):
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
                 block_size]
             input_metadata: metadata for paged attention.
+            alibi_slopes: shape = [num_heads]
         """
         block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.head_mapping,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            None,  # alibi_slopes
-        )
+        num_seqs, num_heads, head_size = query.shape
+        max_num_partitions = (
+            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
+            _PARTITION_SIZE)
+        # NOTE(woosuk): We use a simple heuristic to decide whether to use
+        # PagedAttention V1 or V2. If the number of partitions is 1, we use
+        # V1 to avoid the overhead of reduction. Also, if the number of
+        # sequences or heads is large, we use V1 since there is enough work
+        # to parallelize.
+        # TODO(woosuk): Tune this heuristic.
+        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+        if use_v1:
+            # Run PagedAttention V1.
+            attention_ops.paged_attention_v1(
+                output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=output.dtype,
+                device=output.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=output.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            attention_ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )
 
     def forward(
         self,
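The V2 path allocates `tmp_output`, `exp_sums`, and `max_logits` so that each partition can be attended to independently and the partial results merged afterwards. Conceptually, that merge is a numerically stable softmax combination across partitions. A rough PyTorch sketch of the idea (an illustration only, not the CUDA kernel, and assuming `tmp_output` holds each partition's locally normalized output):

    import torch

    def reduce_partitions(tmp_output, max_logits, exp_sums):
        # tmp_output: [num_seqs, num_heads, num_partitions, head_size]
        # max_logits, exp_sums: [num_seqs, num_heads, num_partitions]
        # Global max over partitions for numerical stability.
        global_max = max_logits.max(dim=-1, keepdim=True).values
        # Rescale each partition's softmax denominator to the shared max.
        rescaled = exp_sums * torch.exp(max_logits - global_max)
        # Each partition's weight in the merged softmax (this sketch ignores the
        # padded partitions of sequences shorter than the longest context).
        weights = rescaled / rescaled.sum(dim=-1, keepdim=True)
        # Weighted sum of the partial outputs -> [num_seqs, num_heads, head_size].
        return (weights.unsqueeze(-1) * tmp_output).sum(dim=-2)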
@@ -253,7 +308,7 @@ class PagedAttention(nn.Module):
             self.single_query_cached_kv_attention(
                 output[num_prompt_tokens:num_valid_tokens],
                 query[num_prompt_tokens:num_valid_tokens], key_cache,
-                value_cache, input_metadata)
+                value_cache, input_metadata, self.get_alibi_slopes())
 
         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
@@ -431,36 +486,5 @@ class PagedAttentionWithALiBi(PagedAttention):
             start += prompt_len
         return output
 
-    def single_query_cached_kv_attention(
-        self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> None:
-        """PagedAttention with ALiBi bias for the generation tokens.
-
-        Args:
-            output: shape = [num_generation_tokens, num_heads, head_size]
-            query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for paged attention.
-        """
-        block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.head_mapping,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            self.alibi_slopes,
-        )
+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        return self.alibi_slopes
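With this change, the ALiBi variant no longer overrides the decode path; it only supplies its slopes to the shared base-class implementation via `get_alibi_slopes`. For background, ALiBi slopes are typically a per-head geometric sequence (per the ALiBi paper). A minimal sketch for a power-of-two head count, illustrative only and not necessarily how vLLM constructs `self.alibi_slopes`:

    import torch

    def make_alibi_slopes(num_heads: int) -> torch.Tensor:
        # Geometric sequence 2^(-8/n), 2^(-16/n), ..., one slope per head
        # (valid as written only when num_heads is a power of two).
        ratio = 2.0 ** (-8.0 / num_heads)
        return torch.tensor([ratio ** (i + 1) for i in range(num_heads)])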