Replace FlashAttention with xformers (#70)

This commit is contained in:
Woosuk Kwon
2023-05-05 02:01:08 -07:00
committed by GitHub
parent 189ae23133
commit c9d5b6d4a8
13 changed files with 89 additions and 133 deletions

View File

@@ -1,8 +1,8 @@
from typing import Optional
from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
import torch.nn as nn
from xformers import ops as xops
from cacheflow import attention_ops
from cacheflow import cache_ops
@@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module):
def __init__(self, scale: float) -> None:
super().__init__()
self.scale = float(scale)
self.attn_op = xops.fmha.cutlass.FwOp()
def multi_query_kv_attention(
self,
@@ -22,32 +23,21 @@ class GPTCacheFlowAttention(nn.Module):
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
max_prompt_len: int,
attn_bias: xops.AttentionBias,
) -> None:
if query.dtype == torch.float:
raise ValueError('The float data type is not supported by '
'FlashAttention. Use the half data type instead.')
head_size = query.shape[-1]
if head_size > 128:
raise ValueError('FlashAttention does not support head_size > 128.')
# Directly call FlashAttention's internal function to avoid allocating
# a new tensor for the output.
_flash_attn_forward(
query,
key,
value,
output,
cumulative_prompt_lens,
cumulative_prompt_lens,
max_prompt_len,
max_prompt_len,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
return_softmax=False,
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
op=self.attn_op,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output.copy_(out.squeeze(0))
return output
def single_query_cached_kv_attention(
self,
@@ -109,8 +99,7 @@ class GPTCacheFlowAttention(nn.Module):
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.cumulative_prompt_lens,
input_metadata.max_prompt_len,
input_metadata.attn_bias,
)
# Wait until the cache op is done.
@@ -143,13 +132,6 @@ class GPTCacheFlowAttention(nn.Module):
return output.view(-1, num_heads * head_size)
class OPTCacheFlowAttention(GPTCacheFlowAttention):
"""OPT uses the same attention mechanism as GPT."""
def __init__(self, scale: float) -> None:
super().__init__(scale)
class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
"""Attention with GPT-NeoX style rotary embedding."""
@@ -207,7 +189,3 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
input_metadata,
cache_event,
)
class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
"""LLaMA uses the GPT-NeoX style rotary embedding."""

View File

@@ -1,6 +1,7 @@
from typing import List, Dict, Tuple
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from cacheflow.sampling_params import SamplingParams
@@ -12,7 +13,6 @@ class InputMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
prompt_lens: List[int],
cumulative_prompt_lens: torch.Tensor,
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
@@ -21,15 +21,14 @@ class InputMetadata:
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.prompt_lens = prompt_lens
self.cumulative_prompt_lens = cumulative_prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables
self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0:
@@ -41,15 +40,13 @@ class InputMetadata:
def __repr__(self) -> str:
return (f'InputMetadata('
f'num_prompts={self.num_prompts}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'max_prompt_len={self.max_prompt_len}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'num_valid_tokens={self.num_valid_tokens}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'max_context_len={self.max_context_len}), '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, '
f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
f'slot_mapping={self.slot_mapping}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'context_lens={self.context_lens}, '
f'block_tables={self.block_tables})')
f'max_context_len={self.max_context_len}), '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}), '
f'slot_mapping={self.slot_mapping}')

View File

@@ -7,7 +7,7 @@ from transformers import LlamaConfig
from cacheflow.models import InputMetadata
from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.attention import GPTNeoXCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
@@ -79,7 +79,7 @@ class LlamaAttention(nn.Module):
input_is_parallel=True,
perform_initialization=False,
)
self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)
def forward(
self,

View File

@@ -202,8 +202,8 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
@@ -277,8 +277,8 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
@@ -353,8 +353,8 @@ class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size

View File

@@ -6,7 +6,7 @@ from torch import nn
from transformers import OPTConfig
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.attention import GPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
@@ -55,7 +55,7 @@ class OPTAttention(nn.Module):
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.attn = OPTCacheFlowAttention(scale=self.scaling)
self.attn = GPTCacheFlowAttention(scale=self.scaling)
def forward(
self,