[Quality] Add code formatter and linter (#326)
@@ -2,7 +2,6 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.utils import set_random_seed
__all__ = [
"InputMetadata",
"get_model",

@@ -8,11 +8,22 @@ from vllm.sequence import SequenceData
class InputMetadata:
"""Metadata for input sequences. Used for PagedAttention.
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to of each token.
context_lens: the length of attention context for each generation token.
max_context_len: The maximum context length.
block_tables: The block tables. (Seq id -> list of physical block)
"""
def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params).
seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData.
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
@@ -6,9 +6,10 @@ from vllm import activation_ops
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_new": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
"gelu_fast": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
# NOTE: The following GELU functions may introduce small rounding errors.
"gelu_new": nn.GELU(approximate="tanh"),
"gelu_fast": nn.GELU(approximate="tanh"),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
}
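The NOTE above flags that mapping "gelu_new", "gelu_fast", and "gelu_pytorch_tanh" to nn.GELU(approximate="tanh") is only approximately equal to the exact GELU. A quick way to see the size of that gap, as a hedged illustration (not part of the commit):

# Illustrative sketch, not part of the diff: exact vs tanh-approximated GELU.
import torch
import torch.nn as nn

x = torch.randn(4, 1024)
exact = nn.GELU()(x)
approx = nn.GELU(approximate="tanh")(x)
print((exact - approx).abs().max())  # typically a small value, on the order of 1e-3 or less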
@@ -25,15 +26,13 @@ class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
Shapes:
x: (num_tokens, 2 * d)
return: (num_tokens, d)
"""
def __init__(self):
super().__init__()
def forward(
self,
x: torch.Tensor, # (num_tokens, 2 * d)
) -> torch.Tensor: # (num_tokens, d)
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
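The docstring above defines the op as x -> silu(x[:d]) * x[d:]; the class itself fills a preallocated `out` tensor via a custom CUDA kernel. A plain-PyTorch reference of the same math, as an illustrative sketch rather than the kernel call:

# Illustrative reference implementation of SiluAndMul (not part of the diff).
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d) -> (num_tokens, d)
    d = x.shape[1] // 2
    return F.silu(x[:, :d]) * x[:, d:]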
@@ -14,6 +14,7 @@ _SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
class PagedAttention(nn.Module):
# pylint: disable=line-too-long
"""GPT-style multi-head PagedAttention.
This class takes flattened 1D query, key, and value tensors as input. The

@@ -54,12 +55,20 @@ class PagedAttention(nn.Module):
def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: xops.AttentionBias,
) -> torch.Tensor:
"""Normal attention for the prompt tokens.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_heads, head_size]
value: shape = [num_prompt_tokens, num_heads, head_size]
"""
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),

@@ -76,12 +85,22 @@ class PagedAttention(nn.Module):
def single_query_cached_kv_attention(
self,
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
"""PagedAttention for the generation tokens.
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
@@ -97,16 +116,32 @@ class PagedAttention(nn.Module):
def forward(
self,
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size, block_size]
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# NOTE: The query, key, and value tensors must be sliced from a qkv
# tensor of shape [num_tokens, 3 * num_heads * head_size].
) -> torch.Tensor:
"""PagedAttention forward pass.
NOTE: The query, key, and value tensors must be sliced from a qkv
tensor of shape [num_tokens, 3 * num_heads * head_size].
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_heads * head_size]
value: shape = [num_tokens, num_heads * head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
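The NOTE in the new docstring requires query, key, and value to be sliced from a fused qkv tensor of shape [num_tokens, 3 * num_heads * head_size]. A minimal sketch of that slicing (illustration only; the toy sizes are invented):

# Illustrative sketch (not part of the diff): slicing q/k/v out of a fused qkv tensor.
import torch

num_tokens, num_heads, head_size = 5, 4, 8
qkv = torch.randn(num_tokens, 3 * num_heads * head_size)
query, key, value = qkv.chunk(chunks=3, dim=-1)
print(query.shape)  # torch.Size([5, 32]) == [num_tokens, num_heads * head_size]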
@@ -136,7 +171,7 @@ class PagedAttention(nn.Module):
# and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],

@@ -149,15 +184,12 @@ class PagedAttention(nn.Module):
if input_metadata.num_generation_tokens > 0:
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
"generating tokens."
)
"generating tokens.")
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens],
key_cache,
value_cache,
input_metadata)
query[num_prompt_tokens:num_valid_tokens], key_cache,
value_cache, input_metadata)
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
@@ -179,9 +211,9 @@ class PagedAttentionWithRoPE(PagedAttention):
super().__init__(num_heads, head_size, scale)
# Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
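For reference, the cos/sin cache built above can be reproduced standalone; a hedged sketch with made-up sizes:

# Illustrative sketch (not part of the diff): rebuilding the rotary cos/sin cache.
import torch

rotary_dim, max_position, base = 64, 2048, 10000
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
print(cache.shape)  # torch.Size([2048, 64]) == [max_position, rotary_dim]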
@@ -195,15 +227,32 @@ class PagedAttentionWithRoPE(PagedAttention):
def forward(
self,
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
) -> torch.Tensor:
""" PagedAttention forward pass with rotary embedding.
Args:
positions: shape = [num_tokens]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_heads * head_size]
value: shape = [num_tokens, num_heads * head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Apply rotary embedding to the query and key before passing them
# to the attention op.
pos_encoding_ops.rotary_embedding_neox(
@@ -13,6 +13,7 @@ from vllm.sequence import SequenceOutputs
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs.

@@ -50,19 +51,20 @@ class Sampler(nn.Module):
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties = _get_penalties(input_metadata)
presence_penalties, frequency_penalties = _get_penalties(
input_metadata)
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(
logits, output_tokens, presence_penalties, frequency_penalties,
self.vocab_size)
logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties, self.vocab_size)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(
temperatures, dtype=logits.dtype, device=logits.device)
t = torch.tensor(temperatures,
dtype=logits.dtype,
device=logits.device)
# Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1))
@@ -75,7 +77,9 @@ class Sampler(nn.Module):
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == probs.shape[0]
if any(p < 1.0 - _SAMPLING_EPS for p in top_ps) or any(k != self.vocab_size for k in top_ks):
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)
if do_top_p or do_top_k:
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
# Sample the next tokens.
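do_top_p and do_top_k just check whether any sequence actually requests truncation before paying for _apply_top_p_top_k. A simplified single-sequence, single-setting sketch of that truncation (illustrative; the real function applies per-sequence top_p/top_k values):

# Illustrative sketch (not part of the diff): top-p/top-k truncation for one sequence.
import torch

def apply_top_p_top_k_ref(probs: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
    # probs: (vocab_size,)
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = probs_sort.cumsum(dim=-1)
    # Top-p: zero tokens once the mass before them already exceeds top_p.
    probs_sort[(probs_sum - probs_sort) > top_p] = 0.0
    # Top-k: zero everything after the k highest-probability tokens.
    probs_sort[top_k:] = 0.0
    # Undo the sort so probabilities line up with token ids again.
    return torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))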
@@ -97,8 +101,7 @@ def _prune_hidden_states(
def _get_penalties(
input_metadata: InputMetadata,
) -> Tuple[List[float], List[float]]:
input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []

@@ -117,9 +120,7 @@ def _get_penalties(
return presence_penalties, frequency_penalties
def _get_output_tokens(
input_metadata: InputMetadata,
) -> List[List[int]]:
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, _ = seq_group
@@ -169,11 +170,13 @@ def _apply_penalties(
device=logits.device)
frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(
frequency_penalties, dtype=logits.dtype, device=logits.device)
frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype,
device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(
presence_penalties, dtype=logits.dtype, device=logits.device)
presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype,
device=logits.device)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
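The comment above points to the OpenAI parameter-details page: a token's logit is reduced by its output count times the frequency penalty, plus the presence penalty if the token has appeared at all. A single-sequence sketch of that rule (illustrative, not the batched implementation in the file):

# Illustrative sketch (not part of the diff): OpenAI-style presence/frequency penalties.
import torch
from typing import List

def apply_penalties_ref(logits: torch.Tensor, output_tokens: List[int],
                        presence_penalty: float, frequency_penalty: float) -> torch.Tensor:
    # logits: (vocab_size,) for a single sequence.
    counts = torch.zeros_like(logits)
    for token_id in output_tokens:
        counts[token_id] += 1
    logits = logits - counts * frequency_penalty
    logits = logits - (counts > 0).to(logits.dtype) * presence_penalty
    return logits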
@@ -183,9 +186,7 @@ def _apply_penalties(
return logits
def _get_temperatures(
input_metadata: InputMetadata,
) -> List[float]:
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):

@@ -252,8 +253,9 @@ def _apply_top_p_top_k(
probs_sort[top_k_mask] = 0.0
# Re-sort the probabilities.
probs = torch.gather(
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
probs = torch.gather(probs_sort,
dim=-1,
index=torch.argsort(probs_idx, dim=-1))
return probs
@@ -296,8 +298,9 @@ def _sample_from_prompt(
# Random sampling.
# Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of
next_token_ids = torch.multinomial(
prob, num_samples=num_seqs, replacement=True)
next_token_ids = torch.multinomial(prob,
num_samples=num_seqs,
replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
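Sampling `best_of` candidates for a prompt is just a multinomial draw with replacement over the token distribution; a toy illustration (not part of the commit):

# Illustrative sketch (not part of the diff): drawing best_of candidate tokens.
import torch

probs = torch.softmax(torch.randn(32), dim=-1)
best_of = 3
next_token_ids = torch.multinomial(probs, num_samples=best_of, replacement=True).tolist()
print(next_token_ids)  # e.g. [7, 19, 7]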
@@ -315,8 +318,9 @@ def _sample_from_generation_tokens(
if sampling_params.use_beam_search:
# Beam search.
# Add cumulative logprobs for the sequences in the group.
seq_logprobs = torch.tensor(
seq_logprobs, dtype=torch.float, device=logprobs.device)
seq_logprobs = torch.tensor(seq_logprobs,
dtype=torch.float,
device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
vocab_size = logprobs.size(-1)

@@ -353,8 +357,9 @@ def _sample_from_generation_tokens(
else:
# Random sampling.
# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(
probs, num_samples=1, replacement=True)
next_token_ids = torch.multinomial(probs,
num_samples=1,
replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids
return parent_seq_ids, next_token_ids

@@ -381,15 +386,16 @@ def _sample(
# Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(
logprob, sampling_params.logprobs)
next_logprobs = _get_topk_logprobs(logprob,
sampling_params.logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id, seq_id, next_token_id, output_logprobs)
seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
next_token_id,
output_logprobs)
else:
# Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)]

@@ -399,22 +405,24 @@ def _sample(
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids]
for seq_id in seq_ids
]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {}
for i, seq_id in enumerate(seq_ids):
for j, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[i], sampling_params.logprobs)
logprob[j], sampling_params.logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(
seq_ids, parent_seq_ids, next_token_ids):
i = seq_ids.index(parent_seq_id)
seq_ids, parent_seq_ids, next_token_ids):
j = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[i, next_token_id].item()
output_logprobs[next_token_id] = logprob[j,
next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id,
parent_seq_id,
@@ -6,8 +6,9 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM, GPTNeoXForCausalLM,
LlamaForCausalLM, OPTForCausalLM)
from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
GPTNeoXForCausalLM, LlamaForCausalLM,
OPTForCausalLM)
from vllm.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes.

@@ -28,8 +29,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
return _MODEL_REGISTRY[arch]
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
)
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
def get_model(model_config: ModelConfig) -> nn.Module:

@@ -46,8 +46,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
model.load_weights(
model_config.model, model_config.download_dir,
model_config.use_np_weights)
model.load_weights(model_config.model, model_config.download_dir,
model_config.use_np_weights)
model = model.cuda()
return model.eval()

@@ -4,8 +4,6 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
__all__ = [
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
@@ -1,5 +1,6 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.

@@ -47,19 +48,25 @@ class GPT2Attention(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim ** -0.5
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
bias=True, gather_output=False,
self.c_attn = ColumnParallelLinear(self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim,
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale)
def forward(

@@ -72,8 +79,8 @@ class GPT2Attention(nn.Module):
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output

@@ -87,11 +94,15 @@ class GPT2MLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
bias=True, gather_output=False,
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)

@@ -107,7 +118,8 @@ class GPT2Block(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config)
@@ -145,9 +157,9 @@ class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
self.config = config
assert config.add_cross_attention == False
assert config.scale_attn_by_inverse_layer_idx == False
assert config.reorder_and_upcast_attn == False
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it

@@ -180,8 +192,8 @@ class GPT2Model(nn.Module):
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = self.ln_f(hidden_states)
return hidden_states

@@ -206,24 +218,26 @@ class GPT2LMHeadModel(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.

@@ -248,16 +262,20 @@ class GPT2LMHeadModel(nn.Module):
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
padded_vocab_size = (param.shape[0] *
tensor_model_parallel_world_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along the head dimension.
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
@@ -266,11 +284,13 @@ class GPT2LMHeadModel(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
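The reflowed comments above describe sharding the fused QKV weight of shape [3 * num_heads * head_size, hidden_size] along the head dimension. A toy-sized sketch of that view/slice/reshape (numbers invented for illustration):

# Illustrative sketch (not part of the diff): sharding a fused QKV weight by heads.
import torch

total_num_heads, head_size, hidden_size, world_size, rank = 4, 8, 32, 2, 0
num_heads = total_num_heads // world_size
fused_qkv = torch.randn(3 * total_num_heads * head_size, hidden_size)

head_start = rank * num_heads
head_end = (rank + 1) * num_heads
shard = fused_qkv.view(3, total_num_heads, head_size, hidden_size)
shard = shard[:, head_start:head_end, :, :].reshape(-1, hidden_size)
print(shard.shape)  # torch.Size([48, 32]) == [3 * num_heads * head_size, hidden_size]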
@@ -1,5 +1,6 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.

@@ -49,19 +50,25 @@ class GPTBigCodeAttention(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim ** -0.5
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
bias=True, gather_output=False,
self.c_attn = ColumnParallelLinear(self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim,
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale)
def forward(

@@ -74,8 +81,8 @@ class GPTBigCodeAttention(nn.Module):
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output

@@ -89,11 +96,15 @@ class GPTBigMLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
bias=True, gather_output=False,
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)

@@ -109,7 +120,8 @@ class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config)
@@ -147,7 +159,7 @@ class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
self.config = config
assert config.add_cross_attention == False
assert not config.add_cross_attention
self.embed_dim = config.hidden_size

@@ -181,8 +193,8 @@ class GPTBigCodeModel(nn.Module):
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = self.ln_f(hidden_states)
return hidden_states

@@ -207,24 +219,26 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.

@@ -241,9 +255,11 @@ class GPTBigCodeForCausalLM(nn.Module):
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
padded_vocab_size = param.shape[
0] * tensor_model_parallel_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
@@ -258,25 +274,31 @@ class GPTBigCodeForCausalLM(nn.Module):
qkv_array = qkv_array.numpy()
dims_q = n_head * head_dim
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
# q is fine, but k & v have not replicated shape along the first axis
# as long as MQA is not nativly supported, increase memory and replicated
# (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim)
# pylint: disable=unbalanced-tuple-unpacking
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
axis=0)
# q is fine, but k & v have not replicated shape along the first
# axis as long as MQA is not nativly supported, increase memory
# and replicated (head_dim, hidden_dim) to
# (n_heads * head_dim, hidden_dim)
if k.ndim == 2 and v.ndim == 2:
replication = (n_head, 1) # weights
else:
replication = n_head # biases
# replicate n_head times for q, v
k, v = np.tile(k, replication), np.tile(v, replication)
# concat q, k, v along the first axis (n_heads * head_dim, hidden_dim)
# concat q, k, v along the first axis
# (n_heads * head_dim, hidden_dim)
# to (3 * n_heads * head_dim, hidden_dim)
qkv_array = np.concatenate((q, k, v), axis=0)
return torch.from_numpy(qkv_array)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along the head dimension.
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
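The reflowed comments above describe expanding multi-query attention (one shared key/value head) into the multi-head layout vLLM expects, by replicating k and v. A toy sketch of the same split/tile/concat steps (sizes invented for illustration):

# Illustrative sketch (not part of the diff): expanding an MQA checkpoint to MHA layout.
import numpy as np

n_head, head_dim, hidden_dim = 4, 8, 16
qkv = np.random.randn(n_head * head_dim + 2 * head_dim, hidden_dim)

dims_q = n_head * head_dim
q, k, v = np.split(qkv, (dims_q, dims_q + head_dim), axis=0)
# Replicate the shared k/v head n_head times so the layout matches MHA.
k, v = np.tile(k, (n_head, 1)), np.tile(v, (n_head, 1))
expanded = np.concatenate((q, k, v), axis=0)
print(expanded.shape)  # (96, 16) == (3 * n_head * head_dim, hidden_dim)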
@@ -285,13 +307,19 @@ class GPTBigCodeForCausalLM(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:

@@ -1,5 +1,6 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
@@ -48,19 +49,23 @@ class GPTNeoXAttention(nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.query_key_value = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size,
gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
self.query_key_value = ColumnParallelLinear(
config.hidden_size,
3 * config.hidden_size,
gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size,
config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
scaling = self.head_size ** -0.5
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,

@@ -78,8 +83,8 @@ class GPTNeoXAttention(nn.Module):
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.dense(attn_output)
return output

@@ -92,7 +97,8 @@ class GPTNeoXMLP(nn.Module):
config.intermediate_size,
gather_output=False,
perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.hidden_act)
@@ -109,8 +115,10 @@ class GPTNeoXLayer(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config)
self.mlp = GPTNeoXMLP(config)

@@ -154,10 +162,13 @@ class GPTNeoXModel(nn.Module):
super().__init__()
self.config = config
self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
self.embed_in = VocabParallelEmbedding(config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layers = nn.ModuleList(
[GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,

@@ -191,8 +202,10 @@ class GPTNeoXForCausalLM(nn.Module):
super().__init__()
self.config = config
self.gpt_neox = GPTNeoXModel(config)
self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
bias=False, gather_output=False,
self.embed_out = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)

@@ -204,24 +217,28 @@ class GPTNeoXForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.gpt_neox(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.embed_out.weight, hidden_states, input_metadata)
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
_column_parallel_weights = [
"embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
or "rotary_emb.inv_freq" in name):
continue
param = state_dict[name]
if "query_key_value" in name:
@@ -230,17 +247,19 @@ class GPTNeoXForCausalLM(nn.Module):
# required shape is [3 * num_heads * head_size, hidden_size].
# Thus, we need weight conversion.
shard_size = param.shape[0]
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads
if 'query_key_value.weight' in name:
loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
if "query_key_value.weight" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size,
hidden_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif 'query_key_value.bias' in name:
elif "query_key_value.bias" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
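The comments above explain why the conversion exists: the GPT-NeoX checkpoint stores the fused QKV weight with q/k/v interleaved per head ([num_heads * 3 * head_size, hidden_size]), while the required layout is [3 * num_heads * head_size, hidden_size]. A toy sketch of the view/transpose/reshape (sizes invented for illustration):

# Illustrative sketch (not part of the diff): regrouping an interleaved fused QKV weight.
import torch

num_heads, head_size, hidden_size = 2, 4, 8
loaded = torch.randn(num_heads * 3 * head_size, hidden_size)

converted = loaded.view(-1, 3, head_size, hidden_size)  # [num_heads, 3, head_size, hidden]
converted = converted.transpose(0, 1)                   # [3, num_heads, head_size, hidden]
converted = converted.reshape(-1, hidden_size)          # [3 * num_heads * head_size, hidden]
print(converted.shape)  # torch.Size([24, 8])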
@@ -1,5 +1,6 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#

@@ -30,7 +31,6 @@ import torch
from torch import nn
from transformers import LlamaConfig
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm

@@ -56,15 +56,19 @@ class LlamaMLP(nn.Module):
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
bias=False, gather_output=False,
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
if hidden_act != 'silu':
raise ValueError(f'Unsupported activation: {hidden_act}. '
'Only silu is supported for now.')
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
@@ -83,12 +87,14 @@ class LlamaAttention(nn.Module):
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(
hidden_size,

@@ -104,8 +110,10 @@ class LlamaAttention(nn.Module):
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim,
self.scaling, rotary_dim=self.head_dim)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward(
self,

@@ -118,8 +126,8 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output

@@ -138,8 +146,10 @@ class LlamaDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,

@@ -177,9 +187,13 @@ class LlamaModel(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(

@@ -209,6 +223,7 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
@@ -228,39 +243,42 @@ class LlamaForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head.weight, hidden_states, input_metadata)
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
"qkv_proj.weight", "gate_proj.weight",
"up_proj.weight"]
_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True

@@ -275,10 +293,10 @@ class LlamaForCausalLM(nn.Module):
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True

@@ -1,7 +1,9 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
# OPT is set up so that if padding_idx is specified then offset the
# embedding ids by 2 and adjust num_embeddings appropriately. Other
# models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
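The reflowed comment documents OPT's positional-embedding offset hack: the table is allocated with two extra rows and every position id is shifted by 2 at lookup time. A minimal sketch of the idea (the forward shown here is an assumption for illustration; the diff only shows __init__):

# Illustrative sketch (not part of the diff): a learned positional embedding with an offset.
import torch
import torch.nn as nn

class LearnedPositionalEmbeddingSketch(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Allocate two extra rows and remember the offset.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # Shift every position id by the offset before the table lookup (assumed behavior).
        return super().forward(positions + self.offset)

emb = LearnedPositionalEmbeddingSketch(2048, 16)
print(emb(torch.arange(8)).shape)  # torch.Size([8, 16])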
@@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
) -> None:
super().__init__()
self.embed_dim = embed_dim
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
total_num_heads = num_heads
assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim ** -0.5
self.scaling = self.head_dim**-0.5

self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
self.qkv_proj = ColumnParallelLinear(embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
self.out_proj = RowParallelLinear(embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim,
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)

def forward(

@@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output

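The fused qkv_proj is why the projection width is 3 * embed_dim: a single matmul produces query, key, and value, which are then split into equal chunks along the last dimension. A minimal standalone sketch of that split, with a plain nn.Linear standing in for the column-parallel layer:

import torch
import torch.nn as nn

embed_dim, num_tokens = 8, 4
qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)

hidden_states = torch.randn(num_tokens, embed_dim)
qkv = qkv_proj(hidden_states)           # (num_tokens, 3 * embed_dim)
q, k, v = qkv.chunk(chunks=3, dim=-1)   # each (num_tokens, embed_dim)
assert q.shape == k.shape == v.shape == (num_tokens, embed_dim)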
@@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
self.activation_fn = get_act_fn(config.activation_function)

self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
gather_output=False,
perform_initialization=False)
self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
self.fc2 = RowParallelLinear(config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
input_is_parallel=True,
perform_initialization=False)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)

def forward(
self,

@@ -133,11 +146,10 @@ class OPTDecoderLayer(nn.Module):
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event)
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
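The do_layer_norm_before flag captures the difference between the pre-norm OPT checkpoints (layer norm before attention) and the post-norm 350m checkpoint (layer norm after the residual add). A minimal sketch of that control flow, with a stand-in module for self-attention and no KV cache:

import torch
import torch.nn as nn

def attn_block(hidden_states, self_attn, layer_norm, do_layer_norm_before: bool):
    residual = hidden_states
    # Pre-norm variants normalize BEFORE attention.
    if do_layer_norm_before:
        hidden_states = layer_norm(hidden_states)
    hidden_states = self_attn(hidden_states)
    hidden_states = residual + hidden_states
    # The post-norm variant normalizes AFTER the residual connection.
    if not do_layer_norm_before:
        hidden_states = layer_norm(hidden_states)
    return hidden_states

x = torch.randn(4, 8)
out = attn_block(x, nn.Linear(8, 8), nn.LayerNorm(8), do_layer_norm_before=True)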
@@ -167,35 +179,42 @@ class OPTDecoder(nn.Module):
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size

self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.word_embed_proj_dim,
perform_initialization=False)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.word_embed_proj_dim,
perform_initialization=False)
# Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size)

# Project out & in will be replicated if they exist.
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
self.project_out = nn.Linear(config.hidden_size,
config.word_embed_proj_dim,
bias=False)
else:
self.project_out = None

if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
self.project_in = nn.Linear(config.word_embed_proj_dim,
config.hidden_size,
bias=False)
else:
self.project_in = None

# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# Note that the only purpose of `config._remove_final_layer_norm` is to
# keep backward compatibility with checkpoints that have been fine-tuned
# before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(
config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
)
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine)
else:
self.final_layer_norm = None

self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList(
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

def forward(
self,
@@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)

if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)

@@ -241,8 +260,8 @@ class OPTModel(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
return self.decoder(
input_ids, positions, kv_caches, input_metadata, cache_events)
return self.decoder(input_ids, positions, kv_caches, input_metadata,
cache_events)


class OPTForCausalLM(nn.Module):

@@ -264,23 +283,26 @@ class OPTForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens

_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
_column_parallel_weights = [
"embed_tokens.weight", "fc1.weight", "fc1.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]

def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()

for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
continue

@@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
name = "model." + name

is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True

@@ -44,9 +44,9 @@ def hf_model_weights_iterator(
if use_np_cache:
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, 'np')
np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, 'weight_names.json')
weight_names_file = os.path.join(np_folder, "weight_names.json")
with lock:
if not os.path.exists(weight_names_file):
weight_names = []

@@ -57,10 +57,10 @@ def hf_model_weights_iterator(
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
weight_names.append(name)
with open(weight_names_file, 'w') as f:
with open(weight_names_file, "w") as f:
json.dump(weight_names, f)

with open(weight_names_file, 'r') as f:
with open(weight_names_file, "r") as f:
weight_names = json.load(f)

for name in weight_names:

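The np cache converts each checkpoint tensor to a .npy file once and records the tensor names in weight_names.json, so later loads can stream numpy arrays instead of re-reading torch checkpoints. A minimal sketch of the same save/replay idea; the helper names and folder layout here are illustrative, not the vLLM API:

import json
import os

import numpy as np
import torch

def save_np_cache(state_dict, np_folder: str) -> None:
    os.makedirs(np_folder, exist_ok=True)
    weight_names = []
    for name, param in state_dict.items():
        # One .npy file per tensor, keyed by its state_dict name.
        with open(os.path.join(np_folder, name), "wb") as f:
            np.save(f, param.cpu().detach().numpy())
        weight_names.append(name)
    with open(os.path.join(np_folder, "weight_names.json"), "w") as f:
        json.dump(weight_names, f)

def iterate_np_cache(np_folder: str):
    with open(os.path.join(np_folder, "weight_names.json"), "r") as f:
        weight_names = json.load(f)
    for name in weight_names:
        with open(os.path.join(np_folder, name), "rb") as f:
            yield name, torch.from_numpy(np.load(f))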
@@ -86,17 +86,16 @@ def load_tensor_parallel_weights(
for p in column_parallel_weight_names:
if p in param_name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[start_idx:end_idx]
break
for p in row_parallel_weight_names:
if p in param_name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[:, start_idx:end_idx]
break
assert param.shape == loaded_weight.shape, (
f"{param_name} shape mismatch between model and checkpoint: "
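The rewritten slicing makes the intent explicit: column-parallel weights are partitioned along dim 0 (output features) and row-parallel weights along dim 1 (input features), with each rank taking the [start_idx:end_idx] window for its shard. A self-contained sketch of the two cases, using a hypothetical helper name:

import torch

def shard_for_rank(loaded_weight: torch.Tensor, rank: int, world_size: int,
                   column_parallel: bool) -> torch.Tensor:
    # Column-parallel layers split the output dim (rows of the weight matrix);
    # row-parallel layers split the input dim (columns).
    dim = 0 if column_parallel else 1
    shard_size = loaded_weight.shape[dim] // world_size
    start_idx = rank * shard_size
    end_idx = (rank + 1) * shard_size
    if column_parallel:
        return loaded_weight[start_idx:end_idx]
    return loaded_weight[:, start_idx:end_idx]

w = torch.randn(8, 6)
print(shard_for_rank(w, rank=1, world_size=2, column_parallel=True).shape)   # (4, 6)
print(shard_for_rank(w, rank=1, world_size=2, column_parallel=False).shape)  # (8, 3)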