Replace FlashAttention with xformers (#70)

Author: Woosuk Kwon
Date: 2023-05-05 02:01:08 -07:00
Committed by: GitHub
parent 189ae23133
commit c9d5b6d4a8
13 changed files with 89 additions and 133 deletions
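The substance of the commit is not shown in this excerpt, which covers only the RMS-norm test changes: the attention layers switch from flash-attn to xformers' memory-efficient attention. A minimal sketch of the xformers side of that swap, assuming the public xformers.ops API; the tensor shapes and the causal-mask helper here are illustrative, not taken from this diff:

# Minimal sketch (assumed API, not code from this commit): causal
# multi-head attention via xformers' memory-efficient attention.
import torch
import xformers.ops as xops

batch, seq_len, num_heads, head_dim = 1, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim,
                dtype=torch.half, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

# LowerTriangularMask applies causal masking; xformers dispatches to
# an efficient fused kernel for the hardware and dtype.
out = xops.memory_efficient_attention(
    q, k, v, attn_bias=xops.LowerTriangularMask())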


@@ -8,7 +8,8 @@ class RefRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
-        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
+        weight = torch.empty(hidden_size)
+        weight.uniform_(-1e-3, 1e-3)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps
@@ -41,7 +42,7 @@ def test_rms_norm(
 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 128, 2048]:
             for hidden_size in [13, 64, 1024, 5120]:
                 print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
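For context on the hunks above: the diff changes only the reference module's weight initialization and the dtypes exercised by the test loop. A minimal sketch of the full reference module, where __init__ matches the new diff lines and the forward pass is an assumption based on the standard RMSNorm formula y = x / sqrt(mean(x^2) + eps) * w, since the forward is not part of this diff:

# Sketch of the reference module; __init__ follows the diff, the
# forward pass is assumed from the standard RMSNorm formula.
import torch
import torch.nn as nn

class RefRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        weight = torch.empty(hidden_size)
        weight.uniform_(-1e-3, 1e-3)
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize by the root mean square over the hidden dimension,
        # then apply the learned per-channel scale.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(
            variance + self.variance_epsilon)
        return self.weight * hidden_states

Drawing the weight from a small uniform range instead of unit-scale Gaussian noise keeps the reference outputs small, which presumably tightens the comparison against the custom kernel now that torch.bfloat16, with its reduced mantissa precision, is included in the dtype sweep.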