Replace FlashAttention with xformers (#70)
@@ -8,7 +8,8 @@ class RefRMSNorm(nn.Module):
 
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
-        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
+        weight = torch.empty(hidden_size)
+        weight.uniform_(-1e-3, 1e-3)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps
 
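The forward pass of RefRMSNorm falls outside this hunk. A minimal self-contained sketch of the reference module, assuming the usual compute-the-variance-in-fp32 formulation (the forward body below is an assumption for illustration, not part of this commit):

import torch
import torch.nn as nn

class RefRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Small uniform init, matching the change in the hunk above.
        weight = torch.empty(hidden_size)
        weight.uniform_(-1e-3, 1e-3)
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Assumed forward: compute the mean square in fp32 for stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back to the parameter dtype before applying the learned scale.
        if self.weight.dtype in (torch.half, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states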
@@ -41,7 +42,7 @@ def test_rms_norm(
 
 
 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 128, 2048]:
             for hidden_size in [13, 64, 1024, 5120]:
                 print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
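The call under the innermost loop is cut off by the hunk boundary, and the body of test_rms_norm is not shown. A hedged sketch of what such a sweep typically checks, reusing the RefRMSNorm sketch above; the argument names and tolerances are assumptions, and the kernel call is replaced by a stand-in so the sketch stays runnable without the CUDA extension:

import torch

@torch.inference_mode()
def test_rms_norm(num_tokens, hidden_size, dtype):
    # Illustrative harness only; names and tolerances are assumptions.
    x = torch.randn(num_tokens, hidden_size).to(dtype)
    ref = RefRMSNorm(hidden_size).to(dtype)
    expected = ref(x)
    # The real test would invoke the custom RMSNorm kernel here; re-running
    # the reference keeps this sketch self-contained.
    actual = ref(x)
    # 16-bit dtypes get a looser tolerance than fp32.
    atol = 1e-2 if dtype in (torch.half, torch.bfloat16) else 1e-5
    assert torch.allclose(actual, expected, atol=atol)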