Use FP32 in RoPE initialization (#1004)

Co-authored-by: One <imone@tuta.io>
Woosuk Kwon authored 2023-09-11 00:26:35 -07:00, committed by GitHub
parent d6770d1f23
commit e67b4f2c2a
2 changed files with 7 additions and 6 deletions


@@ -133,9 +133,10 @@ def test_rotary_embedding(
                            device="cuda")
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(
+        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq)
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
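
Why FP32 matters here: if the inverse-frequency table is built while a half-precision default dtype is active, each frequency gets rounded to FP16 before the cos/sin cache is computed, and that rounding error is amplified linearly by the position index. The sketch below is illustrative only, not vLLM's module code; the constants (base=10000, rotary_dim=128, max_position=4096) are assumed typical RoPE values, and the FP16 effect is emulated by round-tripping the FP32 table through half precision.

import torch

base = 10000.0
rotary_dim = 128
max_position = 4096

# FP32 reference, matching the fixed code path above.
inv_freq_fp32 = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))

# Emulate a table built under a half-precision default dtype by
# rounding the FP32 result to FP16 (up to half an FP16 ulp of error
# per frequency, roughly 2**-12 relative).
inv_freq_fp16 = inv_freq_fp32.half().float()

# The phase passed to cos/sin is t * inv_freq, so the rounding error
# grows linearly with the position index t and reaches on the order of
# a radian near the end of the context window.
t = torch.arange(max_position, dtype=torch.float)
phase_err = torch.einsum("i,j -> ij", t,
                         (inv_freq_fp16 - inv_freq_fp32).abs())
print(f"max phase error: {phase_err.max().item():.3f} rad")

Since inv_freq is now FP32 from the start, the later .float() cast on it is redundant, which is why the einsum line in the diff drops it.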