Refactor llama family models (#2637)

Roy
2024-02-13 16:09:23 +08:00
committed by GitHub
parent f964493274
commit 5c976a7e1a
17 changed files with 236 additions and 2720 deletions


@@ -21,8 +21,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import math
import torch
from torch import nn
from transformers import LlamaConfig
@@ -40,34 +41,60 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig
from copy import deepcopy
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
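    # ALiBi slopes form a geometric sequence: for n heads, with n a power of
    # two, the slopes are 2^(-8/n), 2^(-16/n), ..., 2^(-8). For head counts
    # that are not a power of two, the remaining slopes are interleaved from
    # the sequence of the next power of two.
    # Illustrative example: _get_alibi_slopes(4) -> [2**-2, 2**-4, 2**-6, 2**-8].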
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
config.hidden_size, [config.intermediate_size] * 2,
bias=False,
linear_method=linear_method)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=False,
linear_method=linear_method)
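        # Only a SwiGLU-style feed-forward is implemented here (gate and up
        # projections merged into gate_up_proj), so the activation check below
        # accepts silu only.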
hidden_act = getattr(config, "hidden_act", "silu")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -84,21 +111,19 @@ class LlamaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
self.total_num_heads = getattr(config, "num_attention_heads", None)
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
        # Default to MHA: if the config defines no separate KV-head count,
        # every attention head gets its own KV head.
self.total_num_kv_heads = getattr(config, "num_key_value_heads",
self.total_num_heads)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
@@ -108,39 +133,68 @@ class LlamaAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
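        # When tp_size exceeds the number of KV heads, each rank keeps one
        # replicated KV head (hence the max(1, ...) above).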
self.head_dim = hidden_size // self.total_num_heads
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
        self.max_position_embeddings = max_position_embeddings
        # InternLM-style configs add a bias term to the attention projections.
bias = getattr(config, "bias", False)
        # StableLM-style configs add a bias to the QKV projection only.
qkv_bias = getattr(config, "use_qkv_bias", False)
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
bias=bias or qkv_bias,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
self.hidden_size,
bias=bias,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
        # Mistral-style configs restrict attention to a sliding window.
sliding_window = getattr(config, "sliding_window", None)
self.postion_embedding = getattr(config, "postion_embedding", "ROPE")
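        # Configs that use ALiBi (e.g. Baichuan-13B-style models) set this
        # attribute to "ALIBI"; all other models fall through to rotary
        # position embeddings below.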
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
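            # Under tensor parallelism each rank owns num_heads of the
            # total_num_heads heads, so it only needs the matching slice of
            # the slope vector.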
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
alibi_slopes=alibi_slopes,
sliding_window=sliding_window)
else:
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
            # StableLM-style configs rotate only a fraction (rope_pct) of each
            # head's dimensions.
rope_pct = getattr(config, "rope_pct", 1)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=int(self.head_dim * rope_pct),
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window)
def forward(
self,
@@ -151,7 +205,8 @@ class LlamaAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
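        # With ALiBi, position information is injected through the attention
        # bias, so q/k are left unrotated.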
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
@@ -164,32 +219,20 @@ class LlamaDecoderLayer(nn.Module):
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
norm: Optional[torch.Tensor] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
config,
linear_method=linear_method,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
config,
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.input_layernorm = deepcopy(norm)
self.post_attention_layernorm = deepcopy(norm)
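        # deepcopy gives each decoder layer its own independent copy of the
        # injected norm module (including its parameters) rather than sharing
        # a single instance across layers.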
def forward(
self,
@@ -226,6 +269,7 @@ class LlamaModel(nn.Module):
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
norm: Optional[torch.Tensor] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
@@ -241,10 +285,10 @@ class LlamaModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config, linear_method)
LlamaDecoderLayer(config, linear_method, norm)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = norm
def forward(
self,
@@ -275,12 +319,18 @@ class LlamaForCausalLM(nn.Module):
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
norm: Optional[torch.Tensor] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = LlamaModel(config, linear_method, lora_config=lora_config)
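        # Llama variants can inject a custom normalization module here; plain
        # Llama falls back to RMSNorm.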
if norm is None:
norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
self.model = LlamaModel(config,
linear_method,
norm=norm,
lora_config=lora_config)
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size