Fix for breaking changes in xformers 0.0.21 (#834)

This commit is contained in:
Woosuk Kwon
2023-08-23 17:44:21 +09:00
committed by GitHub
parent 85ebcda94d
commit 2a4ec90854
2 changed files with 4 additions and 3 deletions

View File

@@ -357,11 +357,12 @@ class PagedAttentionWithALiBi(PagedAttention):
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (prompt_len + 7) // 8 * 8
bias = torch.empty(
1, # batch_size
self.num_heads,
padded_len,
prompt_len,
padded_len,
device=self.alibi_slopes.device,
)[:, :prompt_len, :prompt_len].copy_(bias)
)[:, :, :, :prompt_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None])
attn_bias = LowerTriangularMaskWithTensorBias(bias)
input_metadata.attn_bias.append(attn_bias)