Replace FlashAttention with xformers (#70)
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import OPTConfig
 
 from cacheflow.models import InputMetadata
-from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.attention import GPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
                                     load_tensor_parallel_weights)
@@ -55,7 +55,7 @@ class OPTAttention(nn.Module):
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+        self.attn = GPTCacheFlowAttention(scale=self.scaling)
 
     def forward(
         self,
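For context on the rename: per the commit title, the attention backend now uses xformers rather than FlashAttention, and the model-side class is generalized from OPTCacheFlowAttention to GPTCacheFlowAttention. The sketch below shows what a prompt-phase attention call with xformers' memory_efficient_attention can look like; the tensor shapes and variable names are illustrative assumptions, not the actual cacheflow/models/attention.py code from this PR.

```python
# Minimal sketch (not the PR's implementation): prompt-phase attention
# via xformers' memory-efficient attention with a causal mask.
import torch
import xformers.ops as xops

batch, seq_len, num_heads, head_dim = 1, 16, 12, 64
scale = head_dim ** -0.5  # analogous to self.scaling passed to GPTCacheFlowAttention

# xformers expects [batch, seq_len, num_heads, head_dim]
query = torch.randn(batch, seq_len, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# LowerTriangularMask gives causal attention for the prompt tokens.
out = xops.memory_efficient_attention(
    query, key, value,
    attn_bias=xops.LowerTriangularMask(),
    scale=scale,
)
print(out.shape)  # torch.Size([1, 16, 12, 64])
```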