Add support for LLaMA-2 (#505)
@@ -84,21 +84,26 @@ class LlamaAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        num_kv_heads: int,
     ):
         super().__init__()
         self.hidden_size = hidden_size
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
+        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = (self.total_num_heads //
-                          tensor_model_parallel_world_size)
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % tp_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
         self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5

         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
-            3 * self.total_num_heads * self.head_dim,
+            (self.total_num_heads + 2 * self.total_num_kv_heads) *
+            self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
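
Note (illustrative sketch, not part of the commit): with grouped-query attention (GQA), K and V have fewer heads than Q, so the fused QKV projection is (num_heads + 2 * num_kv_heads) * head_dim wide rather than 3 * num_heads * head_dim. The sizes below are assumed for the example (roughly LLaMA-2-70B-like), not taken from the diff:

    # Fused QKV output width under plain MHA vs. GQA.
    hidden_size, num_heads, num_kv_heads = 8192, 64, 8
    head_dim = hidden_size // num_heads                    # 128
    mha_width = 3 * num_heads * head_dim                   # 24576
    gqa_width = (num_heads + 2 * num_kv_heads) * head_dim  # 10240
    print(mha_width, gqa_width)
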
@@ -113,7 +118,8 @@ class LlamaAttention(nn.Module):
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
                                            self.scaling,
-                                           rotary_dim=self.head_dim)
+                                           rotary_dim=self.head_dim,
+                                           num_kv_heads=self.num_kv_heads)

     def forward(
         self,
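
Note (illustrative, not how the paged kernel is actually implemented): passing num_kv_heads tells the attention op that several query heads share one K/V head. Outside a fused kernel, the same semantics can be emulated by repeating each K/V head num_heads // num_kv_heads times, as in this toy sketch with assumed sizes:

    import torch

    num_heads, num_kv_heads, head_dim, seq = 8, 2, 4, 3
    k = torch.randn(seq, num_kv_heads, head_dim)
    n_rep = num_heads // num_kv_heads               # query heads per KV head
    k_expanded = k.repeat_interleave(n_rep, dim=1)  # (seq, num_heads, head_dim)
    print(k_expanded.shape)                         # torch.Size([3, 8, 4])
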
@@ -124,7 +130,7 @@ class LlamaAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
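
Note (illustrative): chunk(3) only works when q, k, and v are equally sized; under GQA the fused tensor must instead be cut at q_size and kv_size, which is what the uneven split above does. A minimal sketch with assumed toy sizes:

    import torch

    q_size, kv_size = 12, 4
    qkv = torch.arange(q_size + 2 * kv_size, dtype=torch.float32)
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    print(q.shape, k.shape, v.shape)   # (12,), (4,), (4,)
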
@@ -140,6 +146,7 @@ class LlamaDecoderLayer(nn.Module):
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
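
Note (a sketch under the assumption that some configs may omit the field, which the commit itself does not handle): HF LLaMA-2 configs define num_key_value_heads, and for LLaMA-1-style checkpoints it equals num_attention_heads, so the GQA path degenerates to ordinary multi-head attention. A defensive read of the config could fall back explicitly:

    from types import SimpleNamespace

    config = SimpleNamespace(num_attention_heads=32)  # LLaMA-1-style config stub
    num_kv_heads = getattr(config, "num_key_value_heads",
                           config.num_attention_heads)
    print(num_kv_heads)                               # 32: GQA reduces to MHA
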
@@ -259,9 +266,19 @@ class LlamaForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
+        tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        q_proj_shard_size = (self.config.hidden_size // tp_size)
+        kv_proj_shard_size = (self.config.hidden_size //
+                              self.config.num_attention_heads *
+                              self.config.num_key_value_heads // tp_size)
+        attention_weight_specs = [
+            # (weight_name, shard_size, offset)
+            ("q_proj", q_proj_shard_size, 0),
+            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
+            ("v_proj", kv_proj_shard_size,
+             q_proj_shard_size + kv_proj_shard_size),
+        ]
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
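
Note (numbers assumed for illustration, roughly 70B-like with 8-way tensor parallelism): the q shard spans this rank's query heads, while the k/v shards are num_kv_heads / num_heads as large, so the per-rank fused qkv_proj rows come out as q_shard + 2 * kv_shard:

    hidden_size, num_heads, num_kv_heads, tp_size = 8192, 64, 8, 8
    q_proj_shard_size = hidden_size // tp_size             # 1024
    kv_proj_shard_size = (hidden_size // num_heads
                          * num_kv_heads // tp_size)       # 128
    # Per-rank fused qkv_proj rows: q shard + k shard + v shard.
    print(q_proj_shard_size + 2 * kv_proj_shard_size)      # 1280
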
@@ -272,8 +289,7 @@ class LlamaForCausalLM(nn.Module):
             if "embed_tokens" in name or "lm_head" in name:
                 param = state_dict[name]
                 # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
+                padded_vocab_size = (param.shape[0] * tp_size)
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
                 extra_rows = torch.empty(num_extra_rows,
                                          loaded_weight.shape[1])
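
Note (toy numbers assumed, not from the commit): the vocab-parallel embedding is padded so every rank holds an equal slice, and the loader reconstructs the padded size from the per-rank shard before appending uninitialized filler rows:

    import torch

    vocab_size, tp_size, cols = 10, 4, 4
    rows_per_rank = (vocab_size + tp_size - 1) // tp_size   # 3
    padded_vocab_size = rows_per_rank * tp_size             # 12
    num_extra_rows = padded_vocab_size - vocab_size         # 2
    extra_rows = torch.empty(num_extra_rows, cols)
    weight = torch.cat([torch.zeros(vocab_size, cols), extra_rows], dim=0)
    print(weight.shape)                                     # torch.Size([12, 4])
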
@@ -281,18 +297,17 @@ class LlamaForCausalLM(nn.Module):
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

             is_attention_weight = False
-            for stride_id, att_weight_name in enumerate(
-                    ["q_proj", "k_proj", "v_proj"]):
-                if att_weight_name not in name:
+            for weight_name, shard_size, offset in attention_weight_specs:
+                if weight_name not in name:
                     continue
-                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
-                shard_size = param.shape[0] // 3
+                param = state_dict[name.replace(weight_name, "qkv_proj")]
+
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
                     (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[shard_size * stride_id:shard_size *
-                                         (stride_id + 1)]
+                param_slice = param.data[offset:offset + shard_size]
                 assert param_slice.shape == loaded_weight.shape
+
                 param_slice.copy_(loaded_weight)
                 is_attention_weight = True
                 break
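
Note (a toy sketch of the loader's slicing, with assumed sizes): each HF q/k/v tensor is first narrowed to this rank's rows, then copied into the fused qkv_proj parameter at its (offset, offset + shard_size) slice. The old code could assume three equal thirds; with GQA the per-name shard sizes and offsets differ, hence the spec tuples:

    import torch

    q_shard, kv_shard, cols, rank = 8, 2, 4, 0
    specs = [("q_proj", q_shard, 0),
             ("k_proj", kv_shard, q_shard),
             ("v_proj", kv_shard, q_shard + kv_shard)]
    qkv = torch.zeros(q_shard + 2 * kv_shard, cols)       # fused parameter
    weights = {n: torch.full((s * 2, cols), float(i))     # fake 2-rank checkpoint
               for i, (n, s, _) in enumerate(specs)}
    for name, shard_size, offset in specs:
        shard = weights[name][shard_size * rank:shard_size * (rank + 1)]
        qkv[offset:offset + shard_size].copy_(shard)
    print(qkv[:, 0])   # rows 0-7 from q, 8-9 from k, 10-11 from v
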