Replace FlashAttention with xformers (#70)

Woosuk Kwon authored 2023-05-05 02:01:08 -07:00, committed by GitHub
parent 189ae23133
commit c9d5b6d4a8
13 changed files with 89 additions and 133 deletions
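The diff below covers only the LLaMA model file; the core of the commit, per its title, swaps the FlashAttention kernels out for xformers' memory-efficient attention in the shared attention module. A minimal sketch of that substitution, assuming xformers' public `memory_efficient_attention` API; the tensor names, shapes, and the `flash_attn_func` comparison are illustrative, not the actual cacheflow code:

```python
import torch
from xformers.ops import memory_efficient_attention, LowerTriangularMask

# Illustrative shapes only: [batch, seq_len, num_heads, head_dim],
# the layout memory_efficient_attention expects.
q = torch.randn(1, 128, 32, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 128, 32, 128, device="cuda", dtype=torch.float16)
v = torch.randn(1, 128, 32, 128, device="cuda", dtype=torch.float16)

# Before (FlashAttention, roughly):
#   out = flash_attn_func(q, k, v, causal=True)
# After (xformers): a bias object expresses the causal mask instead of a flag.
out = memory_efficient_attention(
    q, k, v,
    attn_bias=LowerTriangularMask(),  # causal (lower-triangular) masking
    scale=1.0 / (128 ** 0.5),         # 1/sqrt(head_dim); also the default
)
```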

cacheflow/models/llama.py

@@ -7,7 +7,7 @@ from transformers import LlamaConfig
 from cacheflow.models import InputMetadata
 from cacheflow.models.activation import SiluAndMul
-from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.attention import GPTNeoXCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
@@ -79,7 +79,7 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
+        self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)
 
     def forward(
         self,
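Dropping FlashAttention also removes the need for a LLaMA-specific attention class: LLaMA and GPT-NeoX both use rotary position embeddings, so LLaMA can reuse GPTNeoXCacheFlowAttention, which pairs rotary embeddings with the shared xformers-backed attention path. The import and the constructor call above are the only changes needed in this file.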