[Quality] Add code formatter and linter (#326)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# coding=utf-8
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -30,7 +31,6 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@@ -56,15 +56,19 @@ class LlamaMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
|
||||
bias=False, gather_output=False,
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||
bias=False, input_is_parallel=True,
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
if hidden_act != 'silu':
|
||||
raise ValueError(f'Unsupported activation: {hidden_act}. '
|
||||
'Only silu is supported for now.')
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
@@ -83,12 +87,14 @@ class LlamaAttention(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
@@ -104,8 +110,10 @@ class LlamaAttention(nn.Module):
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim,
|
||||
self.scaling, rotary_dim=self.head_dim)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -118,8 +126,8 @@ class LlamaAttention(nn.Module):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -138,8 +146,10 @@ class LlamaDecoderLayer(nn.Module):
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -177,9 +187,13 @@ class LlamaModel(nn.Module):
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
@@ -209,6 +223,7 @@ class LlamaModel(nn.Module):
|
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -228,39 +243,42 @@ class LlamaForCausalLM(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
next_tokens = self.sampler(
|
||||
self.lm_head.weight, hidden_states, input_metadata)
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
|
||||
"qkv_proj.weight", "gate_proj.weight",
|
||||
"up_proj.weight"]
|
||||
_column_parallel_weights = [
|
||||
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
|
||||
"gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self, model_name_or_path: str,
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
is_attention_weight = False
|
||||
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
|
||||
for stride_id, att_weight_name in enumerate(
|
||||
["q_proj", "k_proj", "v_proj"]):
|
||||
if att_weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||
shard_size = param.shape[0] // 3
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id
|
||||
:shard_size * (stride_id + 1)]
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
@@ -275,10 +293,10 @@ class LlamaForCausalLM(nn.Module):
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id
|
||||
:shard_size * (stride_id + 1)]
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
|
||||
Reference in New Issue
Block a user