Add unoptimized OPT Attention

This commit is contained in:
Woosuk Kwon
2023-02-23 09:31:55 +00:00
parent b56b6ca0d6
commit d4bc1a4d24
2 changed files with 177 additions and 14 deletions

View File

@@ -1,9 +1,17 @@
"""1D OPT model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import OPTConfig
from transformers import PreTrainedModel
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -31,17 +39,27 @@ class OPTAttention(nn.Module):
self.head_dim = embed_dim // num_heads
self.scaling = self.head_dim**-0.5
# TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
q = self.q_proj(hidden_states) * self.scaling
self.attn = OPTCacheFlowAttention(scale=self.scaling)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# TODO
attn_output = None
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
output = self.out_proj(attn_output)
return output
@@ -66,13 +84,23 @@ class OPTDecoderLayer(nn.Module):
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
@@ -145,6 +173,9 @@ class OPTDecoder(OPTPreTrainedModel):
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
pos_embeds = self.embed_positions(positions)
@@ -153,8 +184,14 @@ class OPTDecoder(OPTPreTrainedModel):
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
for layer in self.layers:
hidden_states = layer(hidden_states)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
@@ -175,8 +212,12 @@ class OPTModel(OPTPreTrainedModel):
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
return self.decoder(input_ids, positions)
return self.decoder(
input_ids, positions, kv_caches, input_metadata, cache_events)
class OPTForCausalLM(OPTPreTrainedModel):
@@ -185,9 +226,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = OPTModel(config)
# the lm_head weight is automatically tied to the embed tokens weight
self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
self.sampler = Sampler(embedding=self.lm_head.weight)
# Initialize weights and apply final processing
self.post_init()
@@ -196,7 +237,11 @@ class OPTForCausalLM(OPTPreTrainedModel):
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
) -> torch.Tensor:
hidden_states = self.model.decoder(input_ids, positions)
logits = self.lm_head(hidden_states).contiguous()
return logits
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, Tuple[int, int]]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(hidden_states, input_metadata)
return next_tokens