Optimize data movement (#20)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -16,40 +16,38 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
|
||||
self.flash_attn = FlashAttention(softmax_scale=self.scale)
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
prompt_lens: List[int],
|
||||
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
|
||||
max_prompt_len: int,
|
||||
) -> None:
|
||||
if query.dtype == torch.float:
|
||||
raise ValueError('The float data type is not supported by '
|
||||
'FlashAttention. Use the half data type instead.')
|
||||
head_size = query.shape[2]
|
||||
head_size = query.shape[-1]
|
||||
if head_size > 128:
|
||||
raise ValueError('FlashAttention does not support head_size > 128.')
|
||||
|
||||
device = query.device
|
||||
prefix_sum = [0]
|
||||
for prompt_len in prompt_lens:
|
||||
prefix_sum.append(prefix_sum[-1] + prompt_len)
|
||||
prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
|
||||
max_prompt_len = max(prompt_lens)
|
||||
|
||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
||||
qkv = torch.stack([query, key, value], dim=1)
|
||||
out = self.flash_attn(
|
||||
qkv,
|
||||
cu_seqlens=prefix_sum,
|
||||
max_s=max_prompt_len,
|
||||
# Directly call FlashAttention's internal function to avoid allocating
|
||||
# a new tensor for the output.
|
||||
_flash_attn_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
cumulative_prompt_lens,
|
||||
cumulative_prompt_lens,
|
||||
max_prompt_len,
|
||||
max_prompt_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)[0]
|
||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
||||
output.copy_(out, non_blocking=True)
|
||||
return_softmax=False,
|
||||
)
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
@@ -90,21 +88,18 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||
# Pre-allocate the output tensor.
|
||||
output = torch.empty_like(query)
|
||||
# NOTE: The query, key, and value tensors must be sliced from a qkv
|
||||
# tensor of shape [num_tokens, 3 * num_heads * head_size].
|
||||
|
||||
# Prune out paddings if any.
|
||||
query = query[:input_metadata.num_valid_tokens]
|
||||
key = key[:input_metadata.num_valid_tokens]
|
||||
value = value[:input_metadata.num_valid_tokens]
|
||||
|
||||
# Reshape the input tensors.
|
||||
# Reshape the query, key, and value tensors.
|
||||
num_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_heads, head_size)
|
||||
value = value.view(-1, num_heads, head_size)
|
||||
output = output.view(-1, num_heads, head_size)
|
||||
|
||||
# Pre-allocate the output tensor.
|
||||
output = torch.empty_like(query)
|
||||
|
||||
# Compute the attention op for prompts.
|
||||
num_prompt_tokens = input_metadata.num_prompt_tokens
|
||||
@@ -114,7 +109,8 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
query[:num_prompt_tokens],
|
||||
key[:num_prompt_tokens],
|
||||
value[:num_prompt_tokens],
|
||||
input_metadata.prompt_lens,
|
||||
input_metadata.cumulative_prompt_lens,
|
||||
input_metadata.max_prompt_len,
|
||||
)
|
||||
|
||||
# Wait until the cache op is done.
|
||||
@@ -122,14 +118,22 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
cache_event.wait()
|
||||
|
||||
# Reshape the keys and values and store them in the cache.
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, input_metadata.slot_mapping)
|
||||
num_valid_tokens = input_metadata.num_valid_tokens
|
||||
if num_valid_tokens > 0:
|
||||
# The stride is 3 because the key and value are sliced from qkv.
|
||||
cache_ops.reshape_and_cache(
|
||||
key[:num_valid_tokens],
|
||||
value[:num_valid_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata.slot_mapping,
|
||||
)
|
||||
|
||||
if input_metadata.num_generation_tokens > 0:
|
||||
# Compute the attention op for generation tokens.
|
||||
self.single_query_cached_kv_attention(
|
||||
output[num_prompt_tokens:],
|
||||
query[num_prompt_tokens:],
|
||||
output[num_prompt_tokens:num_valid_tokens],
|
||||
query[num_prompt_tokens:num_valid_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata)
|
||||
@@ -186,19 +190,15 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||
# Apply rotary embedding to the query and key before passing them
|
||||
# to the attention op.
|
||||
out_query = torch.empty_like(query)
|
||||
out_key = torch.empty_like(key)
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
out_query,
|
||||
out_key,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.cos_sin_cache,
|
||||
)
|
||||
return super().forward(
|
||||
out_query,
|
||||
out_key,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
|
||||
Reference in New Issue
Block a user