Refactor system architecture (#82)

This commit is contained in:
Woosuk Kwon
2023-05-09 15:30:12 -07:00
committed by GitHub
parent 8917782af6
commit 7c041ab578
40 changed files with 194 additions and 446 deletions

View File

@@ -0,0 +1,12 @@
from cacheflow.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from cacheflow.model_executor.models.gpt2 import GPT2LMHeadModel
from cacheflow.model_executor.models.llama import LlamaForCausalLM
from cacheflow.model_executor.models.opt import OPTForCausalLM
__all__ = [
"GPT2LMHeadModel",
"GPTNeoXForCausalLM",
"LlamaForCausalLM",
"OPTForCausalLM",
]

View File

@@ -0,0 +1,261 @@
"""1D GPT-2 model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import GPT2Config
from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
from cacheflow.model_executor.layers.sampler import Sampler
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from cacheflow.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from cacheflow.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPT2Attention(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim ** -0.5
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = GPTCacheFlowAttention(scale=self.scale)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output
class GPT2MLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: GPT2Config,
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
bias=True, gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=True, input_is_parallel=True,
perform_initialization=False)
act_fn = config.activation_function
if act_fn != "gelu_new":
raise ValueError(f"Unsupported activation: {act_fn}. "
"GPT-2 only supports gelu_new for now.")
self.act = torch.nn.GELU(approximate="tanh")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class GPT2Block(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
self.config = config
assert config.add_cross_attention == False
assert config.scale_attn_by_inverse_layer_idx == False
assert config.reorder_and_upcast_attn == False
self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList(
[GPT2Block(config) for _ in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.LongTensor,
position_ids: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPT2LMHeadModel(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
self.config = config
self.transformer = GPT2Model(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
name = "transformer." + name
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
param = state_dict[name]
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)

View File

@@ -0,0 +1,231 @@
"""1D GPT-NeoX model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import GPTNeoXConfig
from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
from cacheflow.model_executor.layers.sampler import Sampler
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from cacheflow.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from cacheflow.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTNeoXAttention(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.query_key_value = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size,
gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
scaling = self.head_size ** -0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
self.attn = GPTNeoXCacheFlowAttention(scaling, rotary_dim)
def forward(
self,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
output, _ = self.dense(attn_output)
return output
class GPTNeoXMLP(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
gather_output=False,
perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
if config.hidden_act != 'gelu':
raise ValueError(f'Unsupported activation: {config.hidden_act}. '
'Only gelu is supported for now.')
self.act = torch.nn.GELU()
def forward(self, hidden_states):
hidden_states, _ = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.dense_4h_to_h(hidden_states)
return hidden_states
class GPTNeoXLayer(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config)
self.mlp = GPTNeoXMLP(config)
def forward(
self,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
attn_input = self.input_layernorm(hidden_states)
attn_output = self.attention(
position_ids=position_ids,
hidden_states=attn_input,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_input = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(mlp_input)
hidden_states = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_input = self.post_attention_layernorm(attn_output)
mlp_output = self.mlp(mlp_input)
hidden_states = mlp_output + attn_output
return hidden_states
class GPTNeoXModel(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.config = config
self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
input_ids: torch.LongTensor,
position_ids: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class GPTNeoXForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.gpt_neox = GPTNeoXModel(config)
self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
bias=False, gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.gpt_neox(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.embed_out.weight, hidden_states, input_metadata)
return next_tokens
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue
param = state_dict[name]
if "query_key_value" in name:
# NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
# [num_heads * 3 * head_size, hidden_size], while the
# required shape is [3 * num_heads * head_size, hidden_size].
# Thus, we need weight conversion.
shard_size = param.shape[0]
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads
if 'query_key_value.weight' in name:
loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif 'query_key_value.bias' in name:
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected weight name: {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)

View File

@@ -0,0 +1,267 @@
"""1D LLaMA model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from cacheflow.sequence import SequenceOutputs
from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.layers.activation import SiluAndMul
from cacheflow.model_executor.layers.layernorm import RMSNorm
from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
from cacheflow.model_executor.layers.sampler import Sampler
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from cacheflow.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from cacheflow.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
perform_initialization=False)
if hidden_act != 'silu':
raise ValueError(f'Unsupported activation: {hidden_act}. '
'Only silu is supported for now.')
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class LlamaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5
self.qkv_proj = ColumnParallelLinear(
hidden_size,
3 * self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)
def forward(
self,
positions: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class LlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = LlamaModel(config)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head.weight, hidden_states, input_metadata)
return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
"qkv_proj.weight", "gate_proj.weight",
"up_proj.weight"]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)

View File

@@ -0,0 +1,291 @@
"""1D OPT model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import OPTConfig
from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
from cacheflow.model_executor.layers.sampler import Sampler
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from cacheflow.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from cacheflow.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(self, positions: torch.LongTensor):
return super().forward(positions + self.offset)
class OPTAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
) -> None:
super().__init__()
self.embed_dim = embed_dim
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
total_num_heads = num_heads
assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim ** -0.5
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.attn = GPTCacheFlowAttention(scale=self.scaling)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output
class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
bias=config.enable_bias,
)
self.do_layer_norm_before = config.do_layer_norm_before
assert config.activation_function == 'relu'
self.activation_fn = nn.ReLU()
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
bias=config.enable_bias,
gather_output=False,
perform_initialization=False)
self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
bias=config.enable_bias,
input_is_parallel=True,
perform_initialization=False)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class OPTDecoder(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.word_embed_proj_dim,
perform_initialization=False)
# Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size)
# Project out & in will be replicated if they exist.
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(
config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
pos_embeds = self.embed_positions(positions)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
return hidden_states
class OPTModel(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.decoder = OPTDecoder(config)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
return self.decoder(
input_ids, positions, kv_caches, input_metadata, cache_events)
class OPTForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = OPTModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
def load_weights(self, model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
continue
if name.startswith("decoder."):
name = "model." + name
is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)