Add support for GPT-NeoX (Pythia) (#50)

Author: Woosuk Kwon
Date: 2023-04-28 00:32:10 -07:00
Committed by: GitHub
Parent: aa50b17ca7
Commit: a96d63c21d
9 changed files with 436 additions and 71 deletions


@@ -150,20 +150,20 @@ class OPTCacheFlowAttention(GPTCacheFlowAttention):
         super().__init__(scale)
-class LlamaCacheFlowAttention(GPTCacheFlowAttention):
-    """Llama uses GPT-NeoX style rotary embedding."""
+class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
+    """Attention with GPT-NeoX style rotary embedding."""
     def __init__(
         self,
         scale: float,
-        head_size: int,
+        rotary_dim: int,
         max_position: int = 8192,
         base: int = 10000,
     ) -> None:
         super().__init__(scale)
         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
+        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
         t = torch.arange(max_position).float()
         freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
         cos = freqs.cos()
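
The substantive change in this hunk is that the cos/sin cache is sized by rotary_dim instead of head_size. GPT-NeoX/Pythia applies rotary position embeddings to only a fraction of each attention head (controlled by rotary_pct in the checkpoint config), so the rotated width can be smaller than the head size; for Llama the two are simply equal. A minimal standalone sketch of the cache construction above (the function name build_cos_sin_cache and the example sizes are illustrative, not part of the commit):

import torch

def build_cos_sin_cache(rotary_dim: int,
                        max_position: int = 8192,
                        base: int = 10000) -> torch.Tensor:
    # Inverse frequencies over the rotated dimensions only (rotary_dim, not head_size).
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
    cos = freqs.cos()
    sin = freqs.sin()
    # Shape [max_position, rotary_dim]: first half holds cos, second half sin.
    return torch.cat((cos, sin), dim=-1)

cache = build_cos_sin_cache(rotary_dim=24)  # e.g. a 96-dim head with rotary_pct = 0.25
assert cache.shape == (8192, 24)
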
@@ -174,7 +174,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
         # initializing the model. Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
-        # Embedding size: [max_position, head_size]
+        # Embedding size: [max_position, rotary_dim]
         self.register_buffer('cos_sin_cache', cache, persistent=False)
     def forward(
@@ -190,10 +190,12 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
     ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
+        head_size = value_cache.shape[2]
         pos_encoding_ops.rotary_embedding_neox(
             positions,
             query,
             key,
+            head_size,
             self.cos_sin_cache,
         )
         return super().forward(
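
In the forward pass, head_size is now read from the value cache (value_cache has shape [num_blocks, num_heads, head_size, block_size]) and passed to the kernel explicitly, since the kernel can no longer infer it from the cos/sin cache once rotary_dim may be smaller than head_size. For intuition, here is a rough pure-PyTorch rendering of what a GPT-NeoX style rotary update does to a flat [num_tokens, num_heads * head_size] tensor; it is only a sketch of the convention, not the fused CUDA kernel, which updates query and key in place:

import torch

def rotary_neox_reference(positions: torch.Tensor,      # [num_tokens]
                          x: torch.Tensor,               # [num_tokens, num_heads * head_size]
                          head_size: int,
                          cos_sin_cache: torch.Tensor,   # [max_position, rotary_dim]
                          ) -> torch.Tensor:
    num_tokens = x.shape[0]
    rotary_dim = cos_sin_cache.shape[1]
    x = x.view(num_tokens, -1, head_size)
    # Only the first rotary_dim dims of each head are rotated; the rest pass through.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)  # each [num_tokens, rotary_dim // 2]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)         # broadcast over heads
    # GPT-NeoX convention: split the rotated part into contiguous halves
    # (GPT-J style would instead pair up interleaved even/odd dims).
    x1, x2 = x_rot.chunk(2, dim=-1)
    x_rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((x_rot, x_pass), dim=-1).view(num_tokens, -1)
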
@@ -205,3 +207,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
             input_metadata,
             cache_event,
         )
+class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
+    """LLaMA uses the GPT-NeoX style rotary embedding."""