Refactor system architecture (#82)
cacheflow/model_executor/models/__init__.py  (new file, 12 lines)
@@ -0,0 +1,12 @@
from cacheflow.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from cacheflow.model_executor.models.gpt2 import GPT2LMHeadModel
from cacheflow.model_executor.models.llama import LlamaForCausalLM
from cacheflow.model_executor.models.opt import OPTForCausalLM


__all__ = [
    "GPT2LMHeadModel",
    "GPTNeoXForCausalLM",
    "LlamaForCausalLM",
    "OPTForCausalLM",
]
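The exports above make each causal-LM class importable from a single package path. As a hypothetical sketch (not part of this commit), a loader could dispatch on the `architectures` field of a HuggingFace config using only these exported classes; `get_model_class` below is an illustrative helper, not an existing cacheflow API:

    _MODEL_REGISTRY = {
        "GPT2LMHeadModel": GPT2LMHeadModel,
        "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
        "LlamaForCausalLM": LlamaForCausalLM,
        "OPTForCausalLM": OPTForCausalLM,
    }

    def get_model_class(architecture: str):
        # `architecture` would come from e.g. hf_config.architectures[0].
        return _MODEL_REGISTRY[architecture]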
cacheflow/model_executor/models/gpt2.py  (new file, 261 lines)
@@ -0,0 +1,261 @@
"""1D GPT-2 model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPT2Config

from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
from cacheflow.model_executor.layers.sampler import Sampler
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
                                                   load_tensor_parallel_weights)
from cacheflow.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPT2Attention(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim ** -0.5

        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, bias=True,
                                           gather_output=False,
                                           perform_initialization=False)
        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True,
                                        input_is_parallel=True,
                                        perform_initialization=False)
        self.attn = GPTCacheFlowAttention(scale=self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPT2MLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
                                         bias=True, gather_output=False,
                                         perform_initialization=False)
        self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
                                        bias=True, input_is_parallel=True,
                                        perform_initialization=False)

        act_fn = config.activation_function
        if act_fn != "gelu_new":
            raise ValueError(f"Unsupported activation: {act_fn}. "
                             "GPT-2 only supports gelu_new for now.")
        self.act = torch.nn.GELU(approximate="tanh")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPT2Block(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPT2Model(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.config = config
        assert config.add_cross_attention == False
        assert config.scale_attn_by_inverse_layer_idx == False
        assert config.reorder_and_upcast_attn == False
        self.embed_dim = config.hidden_size

        # Optimization: While the vocab size of GPT-2 is 50257, we extend it
        # to 50304 in order to make it divisible by 64.
        # This improves performance since GPUs are faster if the dimension
        # is divisible by 64. In addition, it allows us to shard the embedding
        # layer across 2, 4, 8, or more GPUs.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList(
            [GPT2Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                hidden_states, kv_caches[i], input_metadata, cache_event)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPT2LMHeadModel(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.config = config
        self.transformer = GPT2Model(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head_weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
    _row_parallel_weights = ["c_proj.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            name = "transformer." + name

            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                loaded_weight = loaded_weight.t()
            param = state_dict[name]

            if name == "transformer.wte.weight":
                # Consider padding in the vocab size.
                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
                num_extra_rows = padded_vocab_size - self.config.vocab_size
                extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
                extra_rows = extra_rows.to(loaded_weight)
                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

            # For the fused QKV linear layer, manually shard the weights.
            if "c_attn" in name:
                # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
                # When tensor parallelism is used, we shard the weights along the head dimension.
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                num_heads = total_num_heads // tensor_model_parallel_world_size
                head_start = tensor_model_parallel_rank * num_heads
                head_end = (tensor_model_parallel_rank + 1) * num_heads

                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif name.endswith(".bias"):
                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected parameter name {name}")
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
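The padded vocabulary used by GPT2Model above simply rounds the HuggingFace vocabulary size up to the next multiple of 64, and load_weights pads the embedding matrix with matching extra rows. A minimal sketch of that arithmetic (standalone, using the GPT-2 numbers from the comment above):

    vocab_size = 50257
    padded_vocab_size = ((vocab_size + 63) // 64) * 64
    assert padded_vocab_size == 50304          # divisible by 64, shardable across 2/4/8 GPUs
    num_extra_rows = padded_vocab_size - vocab_size   # 47 padding rows appended in load_weights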
cacheflow/model_executor/models/gpt_neox.py  (new file, 231 lines)
@@ -0,0 +1,231 @@
"""1D GPT-NeoX model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTNeoXConfig

from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
from cacheflow.model_executor.layers.sampler import Sampler
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
                                                   load_tensor_parallel_weights)
from cacheflow.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTNeoXAttention(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size

        self.query_key_value = ColumnParallelLinear(config.hidden_size,
                                                    3 * config.hidden_size,
                                                    gather_output=False,
                                                    perform_initialization=False)
        self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
                                       input_is_parallel=True,
                                       perform_initialization=False)

        scaling = self.head_size ** -0.5
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        self.attn = GPTNeoXCacheFlowAttention(scaling, rotary_dim)

    def forward(
        self,
        position_ids: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)

        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(
            position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
                                                  config.intermediate_size,
                                                  gather_output=False,
                                                  perform_initialization=False)
        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
                                               input_is_parallel=True,
                                               perform_initialization=False)
        if config.hidden_act != 'gelu':
            raise ValueError(f'Unsupported activation: {config.hidden_act}. '
                             'Only gelu is supported for now.')
        self.act = torch.nn.GELU()

    def forward(self, hidden_states):
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config)
        self.mlp = GPTNeoXMLP(config)

    def forward(
        self,
        position_ids: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


class GPTNeoXModel(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                               perform_initialization=False)
        self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_in(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gpt_neox = GPTNeoXModel(config)
        self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
                                              bias=False, gather_output=False,
                                              perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.gpt_neox(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.embed_out.weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_in.weight", "embed_out.weight",
                                "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            param = state_dict[name]
            if "query_key_value" in name:
                # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
                # [num_heads * 3 * head_size, hidden_size], while the
                # required shape is [3 * num_heads * head_size, hidden_size].
                # Thus, we need weight conversion.
                shard_size = param.shape[0]
                loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
                                              :shard_size * (tensor_model_parallel_rank + 1)]

                num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // num_heads
                if 'query_key_value.weight' in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif 'query_key_value.bias' in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected weight name: {name}")
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
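The QKV layout conversion in load_weights above can be checked in isolation. A minimal sketch with toy sizes (no tensor parallelism, dummy values rather than real weights), showing the move from HuggingFace's [num_heads * 3 * head_size, hidden_size] layout to the [3 * num_heads * head_size, hidden_size] layout expected here:

    import torch

    num_heads, head_size, hidden_size = 4, 8, 32
    # HF layout: [num_heads * 3 * head_size, hidden_size]
    w = torch.randn(num_heads * 3 * head_size, hidden_size)
    # Regroup so all Q rows come first, then K, then V.
    w = w.view(num_heads, 3, head_size, hidden_size)
    w = w.transpose(0, 1).reshape(-1, hidden_size)
    assert w.shape == (3 * num_heads * head_size, hidden_size)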
cacheflow/model_executor/models/llama.py  (new file, 267 lines)
@@ -0,0 +1,267 @@
"""1D LLaMA model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from cacheflow.sequence import SequenceOutputs
from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.layers.activation import SiluAndMul
from cacheflow.model_executor.layers.layernorm import RMSNorm
from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
from cacheflow.model_executor.layers.sampler import Sampler
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
                                                   load_tensor_parallel_weights)
from cacheflow.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)

KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
                                                 bias=False, gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                           bias=False, input_is_parallel=True,
                                           perform_initialization=False)
        if hidden_act != 'silu':
            raise ValueError(f'Unsupported activation: {hidden_act}. '
                             'Only silu is supported for now.')
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.scaling = self.head_dim ** -0.5

        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            3 * self.total_num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)

    def forward(
        self,
        positions: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(
            positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                                   perform_initialization=False)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class LlamaForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LlamaModel(config)
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head.weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
                                "qkv_proj.weight", "gate_proj.weight",
                                "up_proj.weight"]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "rotary_emb.inv_freq" in name:
                continue

            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank
                    :shard_size * (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id
                                         :shard_size * (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank
                    :shard_size * (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id
                                         :shard_size * (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
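LlamaMLP above fuses gate_proj and up_proj into a single gate_up_proj and relies on SiluAndMul (defined elsewhere in cacheflow.model_executor.layers.activation, not shown in this commit). A minimal reference sketch of the assumed behavior, inferred from the gate_proj/up_proj stacking order (stride_id 0 then 1) in load_weights: the first half of the last dimension is the gate, the second half is the up projection.

    import torch
    import torch.nn.functional as F

    def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
        # gate_up: [..., 2 * intermediate_size]; gate rows come first, then up.
        gate, up = gate_up.chunk(2, dim=-1)
        return F.silu(gate) * up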
cacheflow/model_executor/models/opt.py  (new file, 291 lines)
@@ -0,0 +1,291 @@
"""1D OPT model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import OPTConfig

from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
from cacheflow.model_executor.layers.sampler import Sampler
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
                                                   load_tensor_parallel_weights)
from cacheflow.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately.
        # Other models don't have this hack.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.LongTensor):
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim ** -0.5

        self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
                                             gather_output=False,
                                             perform_initialization=False)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                          input_is_parallel=True,
                                          perform_initialization=False)
        self.attn = GPTCacheFlowAttention(scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        assert config.activation_function == 'relu'
        self.activation_fn = nn.ReLU()

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
                                        bias=config.enable_bias,
                                        gather_output=False,
                                        perform_initialization=False)
        self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
                                     bias=config.enable_bias,
                                     input_is_parallel=True,
                                     perform_initialization=False)
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the MLP
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the MLP
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.word_embed_proj_dim,
                                                   perform_initialization=False)
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1;
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
            )
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states, kv_caches[i], input_metadata, cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        return hidden_states


class OPTModel(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.decoder = OPTDecoder(config)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        return self.decoder(
            input_ids, positions, kv_caches, input_metadata, cache_events)


class OPTForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = OPTModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head_weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
    _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "lm_head.weight" in name:
                continue

            if name.startswith("decoder."):
                name = "model." + name

            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank
                    :shard_size * (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id
                                         :shard_size * (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
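The q_proj/k_proj/v_proj loop in load_weights above (used by both the OPT and LLaMA loaders) copies each separate HuggingFace tensor into one third of the fused qkv_proj parameter, selecting this rank's shard first. A minimal sketch of the indexing with toy sizes and a single GPU (tensor_model_parallel_rank = 0), using dummy tensors in place of real checkpoints:

    import torch

    embed_dim = 8
    qkv = torch.zeros(3 * embed_dim, embed_dim)   # fused qkv_proj.weight
    shard_size = qkv.shape[0] // 3                # rows per q/k/v block
    hf_weights = [torch.full((embed_dim, embed_dim), 1.0),   # q_proj
                  torch.full((embed_dim, embed_dim), 2.0),   # k_proj
                  torch.full((embed_dim, embed_dim), 3.0)]   # v_proj
    for stride_id, w in enumerate(hf_weights):
        # stride_id selects the q, k, or v block inside the fused parameter.
        qkv[shard_size * stride_id:shard_size * (stride_id + 1)].copy_(w)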