use flash-attn via xformers (#877)

Aman Gupta Karmani
2023-08-30 00:52:13 -04:00
committed by GitHub
parent d2b2eed67c
commit 75471386de
2 changed files with 0 additions and 5 deletions

@@ -266,7 +266,6 @@ def run_multi_query_kv_attention(
     qkv.uniform_(-1e-3, 1e-3)
     query, key, value = qkv.unbind(dim=1)

-    attn_op = xops.fmha.cutlass.FwOp()
     attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
     output = xops.memory_efficient_attention_forward(
         query.unsqueeze(0),
@@ -275,7 +274,6 @@ def run_multi_query_kv_attention(
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
-        op=attn_op,
     )
     output = output.squeeze(0)
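
The deleted lines stop pinning the attention op to xops.fmha.cutlass.FwOp: with no op= argument, xformers picks the backend for memory_efficient_attention_forward itself, which per the commit title is flash-attn when it is available. A minimal standalone sketch of the call after this change follows; it is not part of the diff, and the sequence lengths, head count, head size, and dtype are illustrative assumptions.

import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# Illustrative sizes, not taken from the test file.
seq_lens = [16, 32]
num_heads, head_size = 8, 64
num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size ** 0.5))

qkv = torch.empty(num_tokens, 3, num_heads, head_size,
                  dtype=torch.float16, device="cuda")
qkv.uniform_(-1e-3, 1e-3)
query, key, value = qkv.unbind(dim=1)

attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
output = xops.memory_efficient_attention_forward(
    query.unsqueeze(0),
    key.unsqueeze(0),
    value.unsqueeze(0),
    attn_bias=attn_bias,
    p=0.0,
    scale=scale,
    # No op= argument: xformers dispatches to the fastest supported backend
    # (flash-attn if installed) instead of the pinned cutlass forward op.
)
output = output.squeeze(0)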