Replace FlashAttention with xformers (#70)

Woosuk Kwon
2023-05-05 02:01:08 -07:00
committed by GitHub
parent 189ae23133
commit c9d5b6d4a8
13 changed files with 89 additions and 133 deletions
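
Summary: the prompt-phase multi-query attention now runs through xformers' memory-efficient attention instead of calling FlashAttention's internal `_flash_attn_forward`, so callers pass a prebuilt `xops.AttentionBias` rather than cumulative prompt lengths, and the FlashAttention-specific float32 and head_size > 128 restrictions are dropped. A minimal sketch of the new call path, assuming a CUDA device and half-precision tensors (shapes follow the diff's comments; the variable names here are illustrative, not from this commit):

    import torch
    from xformers import ops as xops

    num_tokens, num_heads, head_size = 8, 12, 64
    query = torch.randn(num_tokens, num_heads, head_size,
                        dtype=torch.half, device='cuda')
    key = torch.randn_like(query)
    value = torch.randn_like(query)

    # Causal bias for a single prompt; the commit itself receives a
    # ready-made AttentionBias via input_metadata (see the diff below).
    attn_bias = xops.LowerTriangularMask()

    out = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),           # xformers expects a leading batch dim
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,                        # no dropout
        scale=head_size ** -0.5,
        op=xops.fmha.cutlass.FwOp(),  # the forward op the commit pins
    )
    out = out.squeeze(0)              # [num_tokens, num_heads, head_size]

Unlike FlashAttention's varlen kernel, which wrote into a caller-provided output buffer, this call returns a fresh tensor, which is what the "Unnecessary copy" TODO in the diff refers to.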

cacheflow/models/attention.py

@@ -1,8 +1,8 @@
 from typing import Optional
 
-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn
+from xformers import ops as xops
 
 from cacheflow import attention_ops
 from cacheflow import cache_ops
@@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module):
     def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)
+        self.attn_op = xops.fmha.cutlass.FwOp()
 
     def multi_query_kv_attention(
         self,
@@ -22,32 +23,21 @@ class GPTCacheFlowAttention(nn.Module):
         query: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
         key: torch.Tensor,                      # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
-        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
-        max_prompt_len: int,
+        attn_bias: xops.AttentionBias,
     ) -> None:
-        if query.dtype == torch.float:
-            raise ValueError('The float data type is not supported by '
-                             'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[-1]
-        if head_size > 128:
-            raise ValueError('FlashAttention does not support head_size > 128.')
-        # Directly call FlashAttention's internal function to avoid allocating
-        # a new tensor for the output.
-        _flash_attn_forward(
-            query,
-            key,
-            value,
-            output,
-            cumulative_prompt_lens,
-            cumulative_prompt_lens,
-            max_prompt_len,
-            max_prompt_len,
-            dropout_p=0.0,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax=False,
-        )
+        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
+        out = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=self.scale,
+            op=self.attn_op,
+        )
+        # TODO(woosuk): Unnecessary copy. Optimize.
+        output.copy_(out.squeeze(0))
         return output
 
     def single_query_cached_kv_attention(
         self,
@@ -109,8 +99,7 @@ class GPTCacheFlowAttention(nn.Module):
             query[:num_prompt_tokens],
             key[:num_prompt_tokens],
             value[:num_prompt_tokens],
-            input_metadata.cumulative_prompt_lens,
-            input_metadata.max_prompt_len,
+            input_metadata.attn_bias,
         )
 
         # Wait until the cache op is done.
@@ -143,13 +132,6 @@ class GPTCacheFlowAttention(nn.Module):
         return output.view(-1, num_heads * head_size)
 
 
-class OPTCacheFlowAttention(GPTCacheFlowAttention):
-    """OPT uses the same attention mechanism as GPT."""
-
-    def __init__(self, scale: float) -> None:
-        super().__init__(scale)
-
-
 class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
     """Attention with GPT-NeoX style rotary embedding."""
 
@@ -207,7 +189,3 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
             input_metadata,
             cache_event,
         )
-
-
-class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
-    """LLaMA uses the GPT-NeoX style rotary embedding."""