Refactor llama family models (#2637)
This commit is contained in:
@@ -21,8 +21,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
@@ -40,34 +41,60 @@ from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.config import LoRAConfig
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2,
|
||||
total_num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(start=1,
|
||||
end=1 + 2 * num_remaining_heads,
|
||||
step=2,
|
||||
dtype=torch.int32)
|
||||
slopes = torch.cat(
|
||||
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
config: LlamaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
config.hidden_size, [config.intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
self.down_proj = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
hidden_act = getattr(config, "hidden_act", "silu")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@@ -84,21 +111,19 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
config: LlamaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_size = config.hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
self.total_num_heads = getattr(config, "num_attention_heads", None)
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
|
||||
# defaut to mha
|
||||
self.total_num_kv_heads = getattr(config, "num_key_value_heads",
|
||||
self.total_num_heads)
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
@@ -108,39 +133,68 @@ class LlamaAttention(nn.Module):
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
# internlm
|
||||
bias = getattr(config, "bias", False)
|
||||
|
||||
# stablelm
|
||||
qkv_bias = getattr(config, "use_qkv_bias", False)
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
bias=bias or qkv_bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
self.hidden_size,
|
||||
bias=bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
# mistral
|
||||
sliding_window = getattr(config, "sliding_window", None)
|
||||
|
||||
self.postion_embedding = getattr(config, "postion_embedding", "ROPE")
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.postion_embedding == "ALIBI":
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
|
||||
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
alibi_slopes=alibi_slopes,
|
||||
sliding_window=sliding_window)
|
||||
else:
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
# stablelm
|
||||
rope_pct = getattr(config, "rope_pct", 1)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=int(self.head_dim * rope_pct),
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -151,7 +205,8 @@ class LlamaAttention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
if self.postion_embedding != "ALIBI":
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
@@ -164,32 +219,20 @@ class LlamaDecoderLayer(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
norm: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = LlamaAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.input_layernorm = deepcopy(norm)
|
||||
self.post_attention_layernorm = deepcopy(norm)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -226,6 +269,7 @@ class LlamaModel(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
norm: Optional[torch.Tensor] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -241,10 +285,10 @@ class LlamaModel(nn.Module):
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(config, linear_method)
|
||||
LlamaDecoderLayer(config, linear_method, norm)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.norm = norm
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -275,12 +319,18 @@ class LlamaForCausalLM(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
norm: Optional[torch.Tensor] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = LlamaModel(config, linear_method, lora_config=lora_config)
|
||||
if norm is None:
|
||||
norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.model = LlamaModel(config,
|
||||
linear_method,
|
||||
norm=norm,
|
||||
lora_config=lora_config)
|
||||
unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
Reference in New Issue
Block a user