use flash-attn via xformers (#877)
This commit is contained in:
committed by
GitHub
parent
d2b2eed67c
commit
75471386de
@@ -61,7 +61,6 @@ class PagedAttention(nn.Module):
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.attn_op = xops.fmha.cutlass.FwOp()
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
@@ -115,7 +114,6 @@ class PagedAttention(nn.Module):
|
||||
attn_bias=input_metadata.attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=self.attn_op,
|
||||
)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output.copy_(out.squeeze(0))
|
||||
@@ -404,7 +402,6 @@ class PagedAttentionWithALiBi(PagedAttention):
|
||||
attn_bias=input_metadata.attn_bias[i],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=self.attn_op,
|
||||
)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output[start:end].copy_(out.squeeze(0))
|
||||
|
||||
Reference in New Issue
Block a user