use flash-attn via xformers (#877)
This commit is contained in:
committed by
GitHub
parent
d2b2eed67c
commit
75471386de
@@ -266,7 +266,6 @@ def run_multi_query_kv_attention(
|
||||
qkv.uniform_(-1e-3, 1e-3)
|
||||
query, key, value = qkv.unbind(dim=1)
|
||||
|
||||
attn_op = xops.fmha.cutlass.FwOp()
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
|
||||
output = xops.memory_efficient_attention_forward(
|
||||
query.unsqueeze(0),
|
||||
@@ -275,7 +274,6 @@ def run_multi_query_kv_attention(
|
||||
attn_bias=attn_bias,
|
||||
p=0.0,
|
||||
scale=scale,
|
||||
op=attn_op,
|
||||
)
|
||||
output = output.squeeze(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user