Add support for GPT-NeoX (Pythia) (#50)
This commit is contained in:
@@ -34,6 +34,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.rotary_dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# Create cos and sin embeddings.
|
||||
@@ -52,13 +53,24 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
|
||||
|
||||
query_rot = query_rot.transpose(0, 1)
|
||||
key_rot = key_rot.transpose(0, 1)
|
||||
cos = F.embedding(positions, self.cos_cached)
|
||||
sin = F.embedding(positions, self.sin_cached)
|
||||
query = query.transpose(0, 1)
|
||||
key = key.transpose(0, 1)
|
||||
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
||||
query = query.transpose(0, 1).contiguous()
|
||||
key = key.transpose(0, 1).contiguous()
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
query_rot = query_rot.transpose(0, 1).contiguous()
|
||||
key_rot = key_rot.transpose(0, 1).contiguous()
|
||||
|
||||
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1)
|
||||
|
||||
# Output query/key shape: [num_tokens, num_tokens, head_size]
|
||||
return query, key
|
||||
|
||||
@@ -69,6 +81,7 @@ def test_rotary_embedding_neox(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
max_position: int,
|
||||
rotary_dim: int,
|
||||
dtype: torch.dtype,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
@@ -77,7 +90,7 @@ def test_rotary_embedding_neox(
|
||||
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
|
||||
|
||||
# Create the rotary embedding.
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
||||
cos = freqs.cos()
|
||||
@@ -92,12 +105,13 @@ def test_rotary_embedding_neox(
|
||||
positions,
|
||||
out_query,
|
||||
out_key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
)
|
||||
|
||||
# Run the reference implementation.
|
||||
ref_rotary_embedding = RefRotaryEmbeddingNeox(
|
||||
dim=head_size,
|
||||
dim=rotary_dim,
|
||||
max_position_embeddings=max_position,
|
||||
base=base,
|
||||
).to(dtype=dtype, device='cuda')
|
||||
@@ -123,5 +137,6 @@ if __name__ == '__main__':
|
||||
num_heads=5,
|
||||
head_size=head_size,
|
||||
max_position=8192,
|
||||
rotary_dim=int(head_size * 0.25),
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user