use flash-attn via xformers (#877)

Author: Aman Gupta Karmani
Date: 2023-08-30 00:52:13 -04:00
Committed by: GitHub
Parent: d2b2eed67c
Commit: 75471386de
2 changed files with 0 additions and 5 deletions


@@ -61,7 +61,6 @@ class PagedAttention(nn.Module):
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.attn_op = xops.fmha.cutlass.FwOp()
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         assert self.num_heads % self.num_kv_heads == 0
@@ -115,7 +114,6 @@ class PagedAttention(nn.Module):
             attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
-            op=self.attn_op,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
         output.copy_(out.squeeze(0))
@@ -404,7 +402,6 @@ class PagedAttentionWithALiBi(PagedAttention):
                 attn_bias=input_metadata.attn_bias[i],
                 p=0.0,
                 scale=self.scale,
-                op=self.attn_op,
             )
             # TODO(woosuk): Unnecessary copy. Optimize.
             output[start:end].copy_(out.squeeze(0))
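
Effect of the change: with the explicit op= argument removed, xformers' memory_efficient_attention_forward selects the best available backend at runtime, which includes the flash-attention kernels where they apply, instead of always running the CUTLASS forward op. Below is a minimal standalone sketch of the call pattern, not vLLM code: the tensor shapes, dtype, and scale are illustrative assumptions, and it requires a CUDA device with xformers installed.

    import torch
    import xformers.ops as xops

    # Illustrative tensors in the layout xformers expects:
    # (batch, seq_len, num_heads, head_size).
    q = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")

    # Before this commit: op=xops.fmha.cutlass.FwOp() pinned the kernel choice.
    # After: leaving op unset lets xformers dispatch to flash-attn when supported.
    out = xops.memory_efficient_attention_forward(q, k, v, p=0.0, scale=64 ** -0.5)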