[misc] use out argument for flash attention (#10822)

Signed-off-by: youkaichao <youkaichao@gmail.com>

Author: youkaichao
Date: 2024-12-02 02:50:10 -08:00 (committed by GitHub)
Parent: e95f275f57
Commit: a4c4daf364

13 changed files with 141 additions and 154 deletions
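
In short (a summary inferred from the diff below, not taken from the commit message): the V1 FlashAttention backend no longer allocates its own result tensor inside a registered custom op. `FlashAttentionImpl.forward` now accepts a preallocated `output` buffer and fills it in place through the `out=` argument of `flash_attn_varlen_func`, updating only the first `num_actual_tokens` rows. A minimal sketch of that out-argument pattern in plain PyTorch, with toy math and hypothetical names:

    import torch

    def toy_kernel(q: torch.Tensor, out: torch.Tensor) -> None:
        # Stand-in for an attention kernel with an out argument: results are
        # written into the caller-provided buffer instead of a fresh allocation.
        torch.mul(q, 2.0, out=out)

    num_tokens, num_heads, head_size = 16, 8, 64
    q = torch.randn(num_tokens, num_heads, head_size)

    # The caller owns (and can reuse) the output buffer.
    output = torch.empty(num_tokens, num_heads, head_size)

    # Slices are views, so writing into out=output[:10] fills only the first
    # 10 rows of the buffer in place (mirroring the num_actual_tokens slicing
    # in the diff below).
    toy_kernel(q[:10], out=output[:10])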


@@ -6,8 +6,6 @@ import torch
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.forward_context import get_forward_context
-from vllm.utils import direct_register_custom_op
 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -113,13 +111,14 @@ class FlashAttentionImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
 
         Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
             kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         Returns:
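
Two signature-level changes show up in the hunk above: `forward` gains an optional `output` buffer, and the documented shapes lose the flattened `num_heads * head_size` form, so query, key, and value are expected to arrive already reshaped to `[num_tokens, num_heads, head_size]`. A hedged caller-side sketch of the new contract (illustrative only, not the actual vLLM call site):

    from typing import Optional
    import torch

    def forward_stub(query: torch.Tensor,
                     key: torch.Tensor,
                     value: torch.Tensor,
                     output: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Mirrors the new contract: the caller supplies the output buffer,
        # the impl writes into it and returns the same tensor.
        assert output is not None
        output.copy_(query)  # placeholder for the real attention result
        return output

    num_tokens, num_heads, num_kv_heads, head_size = 32, 8, 8, 64
    query = torch.randn(num_tokens, num_heads, head_size)
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    value = torch.randn(num_tokens, num_kv_heads, head_size)

    output = torch.empty_like(query)  # preallocated by the caller
    forward_stub(query, key, value, output=output)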
@@ -135,118 +134,42 @@ class FlashAttentionImpl(AttentionImpl):
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
-        # Reshape the query, key, and value tensors.
-        # NOTE(woosuk): We do this outside the custom op to minimize the CPU
-        # overheads from the non-CUDA-graph regions.
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
+        if attn_metadata is None:
+            # Profiling run.
+            return output
 
-        output = torch.empty_like(query)
-        torch.ops.vllm.unified_v1_flash_attention(
-            output,
-            query,
-            key,
-            value,
-            self.num_heads,
-            self.head_size,
-            self.num_kv_heads,
-            kv_cache,
-            self.kv_cache_dtype,
-            k_scale,
-            v_scale,
-            self.scale,
-            self.sliding_window,
-            self.alibi_slopes,
-            self.logits_soft_cap,
-        )
-        return output.view(-1, self.num_heads * self.head_size)
+        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Reshape the input keys and values and store them in the cache.
+        key_cache = kv_cache[0]
+        value_cache = kv_cache[1]
+        torch.ops._C_cache_ops.reshape_and_cache_flash(
+            key[:num_actual_tokens],
+            value[:num_actual_tokens],
+            key_cache,
+            value_cache,
+            attn_metadata.slot_mapping,
+            self.kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+
+        # Compute attention and update output up to `num_actual_tokens`.
+        flash_attn_varlen_func(
+            q=query[:num_actual_tokens],
+            k=key_cache,
+            v=value_cache,
+            out=output[:num_actual_tokens],
+            cu_seqlens_q=attn_metadata.query_start_loc,
+            max_seqlen_q=attn_metadata.max_query_len,
+            cu_seqlens_k=attn_metadata.seq_start_loc,
+            max_seqlen_k=attn_metadata.max_seq_len,
+            softmax_scale=self.scale,
+            causal=True,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=attn_metadata.block_table,
+            softcap=self.logits_soft_cap,
+        )
+        return output
-
-
-def unified_v1_flash_attention(
-    output: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    num_heads: int,
-    head_size: int,
-    num_kv_heads: int,
-    kv_cache: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-    softmax_scale: float,
-    window_size: Optional[List[int]] = None,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    logits_soft_cap: Optional[float] = None,
-) -> None:
-    context = get_forward_context()
-    current_metadata = context.dynamic_forward_context
-    if current_metadata is None:
-        # Profiling run.
-        return
-
-    assert current_metadata is not None
-    assert isinstance(current_metadata, FlashAttentionMetadata)
-    attn_metadata: FlashAttentionMetadata = current_metadata
-    num_actual_tokens = attn_metadata.num_actual_tokens
-
-    # Reshape the input keys and values and store them in the cache.
-    key_cache = kv_cache[0]
-    value_cache = kv_cache[1]
-    torch.ops._C_cache_ops.reshape_and_cache_flash(
-        key[:num_actual_tokens],
-        value[:num_actual_tokens],
-        key_cache,
-        value_cache,
-        attn_metadata.slot_mapping,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-    )
-
-    # Compute attention and update output up to `num_actual_tokens`.
-    flash_attn_varlen_func(
-        q=query[:num_actual_tokens],
-        k=key_cache,
-        v=value_cache,
-        out=output[:num_actual_tokens],
-        cu_seqlens_q=attn_metadata.query_start_loc,
-        max_seqlen_q=attn_metadata.max_query_len,
-        cu_seqlens_k=attn_metadata.seq_start_loc,
-        max_seqlen_k=attn_metadata.max_seq_len,
-        softmax_scale=softmax_scale,
-        causal=True,
-        alibi_slopes=alibi_slopes,
-        window_size=window_size,
-        block_table=attn_metadata.block_table,
-        softcap=logits_soft_cap,
-    )
-
-
-def unified_v1_flash_attention_fake(
-    output: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    num_heads: int,
-    head_size: int,
-    num_kv_heads: int,
-    kv_cache: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-    softmax_scale: float,
-    window_size: Optional[List[int]] = None,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    logits_soft_cap: Optional[float] = None,
-) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="unified_v1_flash_attention",
-    op_func=unified_v1_flash_attention,
-    mutates_args=["kv_cache", "output"],
-    fake_impl=unified_v1_flash_attention_fake,
-)
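
The deleted module-level code above is the other half of the change: `unified_v1_flash_attention`, its fake implementation, and the `direct_register_custom_op` call are removed from this backend, since the attention result is now written into a caller-provided buffer (the companion changes in the commit's other files are not shown here). For reference, a minimal sketch of that registration pattern, reusing the same `direct_register_custom_op` API the deleted code used; the op and function names below are hypothetical, not part of vLLM:

    import torch

    from vllm.utils import direct_register_custom_op


    def toy_attention_with_output(query: torch.Tensor,
                                  output: torch.Tensor) -> None:
        # A real implementation would run attention; the key point is that
        # the op returns None and mutates the caller-provided output buffer.
        output.copy_(query)


    def toy_attention_with_output_fake(query: torch.Tensor,
                                       output: torch.Tensor) -> None:
        # Fake impl for tracing/compilation: no work and no allocation,
        # because the output buffer already exists on the caller side.
        return


    direct_register_custom_op(
        op_name="toy_attention_with_output",
        op_func=toy_attention_with_output,
        mutates_args=["output"],
        fake_impl=toy_attention_with_output_fake,
    )

    # Invoked through the registered namespace, as the deleted code did with
    # torch.ops.vllm.unified_v1_flash_attention; `out` is filled in place.
    q = torch.randn(4, 8, 64)
    out = torch.empty_like(q)
    torch.ops.vllm.toy_attention_with_output(q, out)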