Replace FlashAttention with xformers (#70)
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from xformers import ops as xops
|
||||
|
||||
from cacheflow import attention_ops
|
||||
from cacheflow import cache_ops
|
||||
@@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
self.attn_op = xops.fmha.cutlass.FwOp()
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
@@ -22,32 +23,21 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
|
||||
max_prompt_len: int,
|
||||
attn_bias: xops.AttentionBias,
|
||||
) -> None:
|
||||
if query.dtype == torch.float:
|
||||
raise ValueError('The float data type is not supported by '
|
||||
'FlashAttention. Use the half data type instead.')
|
||||
head_size = query.shape[-1]
|
||||
if head_size > 128:
|
||||
raise ValueError('FlashAttention does not support head_size > 128.')
|
||||
|
||||
# Directly call FlashAttention's internal function to avoid allocating
|
||||
# a new tensor for the output.
|
||||
_flash_attn_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
cumulative_prompt_lens,
|
||||
cumulative_prompt_lens,
|
||||
max_prompt_len,
|
||||
max_prompt_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
return_softmax=False,
|
||||
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query.unsqueeze(0),
|
||||
key.unsqueeze(0),
|
||||
value.unsqueeze(0),
|
||||
attn_bias=attn_bias,
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=self.attn_op,
|
||||
)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output.copy_(out.squeeze(0))
|
||||
return output
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
@@ -109,8 +99,7 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
query[:num_prompt_tokens],
|
||||
key[:num_prompt_tokens],
|
||||
value[:num_prompt_tokens],
|
||||
input_metadata.cumulative_prompt_lens,
|
||||
input_metadata.max_prompt_len,
|
||||
input_metadata.attn_bias,
|
||||
)
|
||||
|
||||
# Wait until the cache op is done.
|
||||
@@ -143,13 +132,6 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
return output.view(-1, num_heads * head_size)
|
||||
|
||||
|
||||
class OPTCacheFlowAttention(GPTCacheFlowAttention):
|
||||
"""OPT uses the same attention mechanism as GPT."""
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__(scale)
|
||||
|
||||
|
||||
class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
|
||||
"""Attention with GPT-NeoX style rotary embedding."""
|
||||
|
||||
@@ -207,7 +189,3 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
|
||||
|
||||
class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
|
||||
"""LLaMA uses the GPT-NeoX style rotary embedding."""
|
||||
|
||||
Reference in New Issue
Block a user