Replace FlashAttention with xformers (#70)

This commit is contained in:
Woosuk Kwon
2023-05-05 02:01:08 -07:00
committed by GitHub
parent 189ae23133
commit c9d5b6d4a8
13 changed files with 89 additions and 133 deletions

View File

@@ -23,7 +23,7 @@ def test_silu_and_mul(
if __name__ == '__main__':
for dtype in [torch.half, torch.float]:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 83, 2048]:
for d in [512, 4096, 13824]:
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')