Replace FlashAttention with xformers (#70)
This commit is contained in:
@@ -7,7 +7,7 @@ from transformers import LlamaConfig
 from cacheflow.models import InputMetadata
 from cacheflow.models.activation import SiluAndMul
-from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.attention import GPTNeoXCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
@@ -79,7 +79,7 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
+        self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)

     def forward(
         self,
Reference in New Issue
Block a user