Replace FlashAttention with xformers (#70)

This commit is contained in:
Woosuk Kwon
2023-05-05 02:01:08 -07:00
committed by GitHub
parent 189ae23133
commit c9d5b6d4a8
13 changed files with 89 additions and 133 deletions

View File

@@ -129,7 +129,7 @@ def test_rotary_embedding_neox(
if __name__ == '__main__':
for dtype in [torch.half, torch.float]:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Running tests for head_size={head_size} and dtype={dtype}')
test_rotary_embedding_neox(