Add support for GPT-NeoX (Pythia) (#50)
This commit is contained in:
@@ -150,20 +150,20 @@ class OPTCacheFlowAttention(GPTCacheFlowAttention):
|
||||
super().__init__(scale)
|
||||
|
||||
|
||||
class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
||||
"""Llama uses GPT-NeoX style rotary embedding."""
|
||||
class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
|
||||
"""Attention with GPT-NeoX style rotary embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale: float,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
super().__init__(scale)
|
||||
|
||||
# Create the cos and sin cache.
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
||||
cos = freqs.cos()
|
||||
@@ -174,7 +174,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
||||
# initializing the model. Make it more robust.
|
||||
torch_dtype = torch.get_default_dtype()
|
||||
cache = cache.to(torch_dtype)
|
||||
# Embedding size: [max_position, head_size]
|
||||
# Embedding size: [max_position, rotary_dim]
|
||||
self.register_buffer('cos_sin_cache', cache, persistent=False)
|
||||
|
||||
def forward(
|
||||
@@ -190,10 +190,12 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||
# Apply rotary embedding to the query and key before passing them
|
||||
# to the attention op.
|
||||
head_size = value_cache.shape[2]
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
self.cos_sin_cache,
|
||||
)
|
||||
return super().forward(
|
||||
@@ -205,3 +207,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
|
||||
|
||||
class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
|
||||
"""LLaMA uses the GPT-NeoX style rotary embedding."""
|
||||
|
||||
Reference in New Issue
Block a user