Replace FlashAttention with xformers (#70)
This commit is contained in:
@@ -129,7 +129,7 @@ def test_rotary_embedding_neox(
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
for dtype in [torch.half, torch.float]:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
print(f'Running tests for head_size={head_size} and dtype={dtype}')
|
||||
test_rotary_embedding_neox(
|
||||
|
||||
Reference in New Issue
Block a user