Replace FlashAttention with xformers (#70)
@@ -1,8 +1,8 @@
 from typing import Optional

-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn
+from xformers import ops as xops

 from cacheflow import attention_ops
 from cacheflow import cache_ops
@@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module):
     def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)
+        self.attn_op = xops.fmha.cutlass.FwOp()

     def multi_query_kv_attention(
         self,
@@ -22,32 +23,21 @@ class GPTCacheFlowAttention(nn.Module):
         query: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
         key: torch.Tensor,          # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
-        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
-        max_prompt_len: int,
+        attn_bias: xops.AttentionBias,
     ) -> None:
-        if query.dtype == torch.float:
-            raise ValueError('The float data type is not supported by '
-                             'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[-1]
-        if head_size > 128:
-            raise ValueError('FlashAttention does not support head_size > 128.')
-
-        # Directly call FlashAttention's internal function to avoid allocating
-        # a new tensor for the output.
-        _flash_attn_forward(
-            query,
-            key,
-            value,
-            output,
-            cumulative_prompt_lens,
-            cumulative_prompt_lens,
-            max_prompt_len,
-            max_prompt_len,
-            dropout_p=0.0,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax=False,
+        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
+        out = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=self.scale,
+            op=self.attn_op,
         )
+        # TODO(woosuk): Unnecessary copy. Optimize.
+        output.copy_(out.squeeze(0))
+        return output

     def single_query_cached_kv_attention(
         self,
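Note (illustration, not part of the diff): the new path packs all prompt tokens along a single sequence dimension and lets an xformers attention bias enforce per-prompt causal masking, instead of passing FlashAttention's cumulative lengths. Below is a minimal self-contained sketch of the same call, assuming xformers is installed and a CUDA device is available; all tensor sizes are made up for illustration.

import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# Two prompts of lengths 3 and 5, packed into one tensor of 8 tokens.
prompt_lens = [3, 5]
num_tokens = sum(prompt_lens)
num_heads, head_size = 12, 64

query = torch.randn(num_tokens, num_heads, head_size,
                    dtype=torch.half, device='cuda')
key = torch.randn_like(query)
value = torch.randn_like(query)

# The block-diagonal causal bias restricts each token to earlier tokens of
# its own prompt, replacing the cumulative_prompt_lens/max_prompt_len pair.
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)

# xformers expects a leading batch dimension, hence the unsqueeze/squeeze.
out = xops.memory_efficient_attention_forward(
    query.unsqueeze(0),
    key.unsqueeze(0),
    value.unsqueeze(0),
    attn_bias=attn_bias,
    p=0.0,
    scale=head_size ** -0.5,
    op=xops.fmha.cutlass.FwOp(),
)
out = out.squeeze(0)  # back to [num_tokens, num_heads, head_size]

The dtype and head-size checks deleted above were tied to FlashAttention's own constraints, which is why they are dropped together with the FlashAttention call.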
@@ -109,8 +99,7 @@ class GPTCacheFlowAttention(nn.Module):
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.cumulative_prompt_lens,
-                input_metadata.max_prompt_len,
+                input_metadata.attn_bias,
             )

         # Wait until the cache op is done.
@@ -143,13 +132,6 @@ class GPTCacheFlowAttention(nn.Module):
         return output.view(-1, num_heads * head_size)


-class OPTCacheFlowAttention(GPTCacheFlowAttention):
-    """OPT uses the same attention mechanism as GPT."""
-
-    def __init__(self, scale: float) -> None:
-        super().__init__(scale)
-
-
 class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
     """Attention with GPT-NeoX style rotary embedding."""

@@ -207,7 +189,3 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
             input_metadata,
             cache_event,
         )
-
-
-class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
-    """LLaMA uses the GPT-NeoX style rotary embedding."""
@@ -1,6 +1,7 @@
 from typing import List, Dict, Tuple

 import torch
+from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from cacheflow.sampling_params import SamplingParams

@@ -12,7 +13,6 @@ class InputMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]],
         seq_logprobs: Dict[int, float],     # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
-        cumulative_prompt_lens: torch.Tensor,
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,
@@ -21,15 +21,14 @@ class InputMetadata:
         self.seq_groups = seq_groups
         self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
-        self.cumulative_prompt_lens = cumulative_prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables

+        self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
-        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
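Note (illustration only): InputMetadata now precomputes the xformers attention bias from the prompt lengths. The snippet below, with made-up prompt lengths, shows what that bias encodes; materialize() is an xformers debugging helper that renders the bias as a dense tensor, with 0 where attention is allowed and -inf where it is masked.

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# Two prompts of lengths 2 and 3, packed back to back (5 tokens total).
prompt_lens = [2, 3]
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)

# Render the bias densely for inspection.
total = sum(prompt_lens)
print(attn_bias.materialize((total, total)))
# Token 2 (the first token of the second prompt) cannot attend to tokens 0-1,
# and every token only sees itself and earlier tokens of its own prompt, so
# no cumulative prompt lengths need to be carried around separately.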
@@ -41,15 +40,13 @@ class InputMetadata:

     def __repr__(self) -> str:
         return (f'InputMetadata('
-                f'num_prompts={self.num_prompts}, '
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'max_prompt_len={self.max_prompt_len}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
-                f'num_valid_tokens={self.num_valid_tokens}, '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'max_context_len={self.max_context_len}), '
+                f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'num_prompts={self.num_prompts}, '
                 f'prompt_lens={self.prompt_lens}, '
-                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
                 f'slot_mapping={self.slot_mapping}, '
+                f'num_generation_tokens={self.num_generation_tokens}, '
                 f'context_lens={self.context_lens}, '
-                f'block_tables={self.block_tables})')
+                f'max_context_len={self.max_context_len}), '
+                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+                f'block_tables={self.block_tables}), '
+                f'slot_mapping={self.slot_mapping}')
@@ -7,7 +7,7 @@ from transformers import LlamaConfig

 from cacheflow.models import InputMetadata
 from cacheflow.models.activation import SiluAndMul
-from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.attention import GPTNeoXCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
@@ -79,7 +79,7 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
+        self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)

     def forward(
         self,
@@ -202,8 +202,8 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
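Note (illustration only): the estimate surrounding the updated comment is unchanged; the sketch below just replays its arithmetic with assumed OPT-13B-like numbers. hidden_size, ffn_size, tensor_parallel_size and the 2-byte activation size are placeholders, not values taken from this diff.

# Assumed example values; the real analyzer reads them from the model config.
max_num_batched_tokens = 2560
hidden_size = 5120
ffn_size = 4 * hidden_size
tensor_parallel_size = 1
dtype_size = 2  # bytes per element, assuming fp16 activations

residual = max_num_batched_tokens * hidden_size
qkv = 3 * (max_num_batched_tokens * hidden_size) // tensor_parallel_size
ffn = max_num_batched_tokens * ffn_size // tensor_parallel_size

# Because memory-efficient attention never materializes the
# [tokens x tokens] attention maps, they do not appear in this budget.
for name, numel in [('residual', residual), ('qkv', qkv), ('ffn', ffn)]:
    print(f'{name}: {numel * dtype_size / 1024**2:.0f} MiB')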
@@ -277,8 +277,8 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
@@ -353,8 +353,8 @@ class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import OPTConfig

 from cacheflow.models import InputMetadata
-from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.attention import GPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
                                     load_tensor_parallel_weights)
@@ -55,7 +55,7 @@ class OPTAttention(nn.Module):
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+        self.attn = GPTCacheFlowAttention(scale=self.scaling)

     def forward(
         self,