Remove MPTConfig (#1529)

Woosuk Kwon
2023-11-01 15:29:05 -07:00
committed by GitHub
parent 7e90a2d117
commit 1fe0990023
6 changed files with 26 additions and 102 deletions
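
The change drops vLLM's hand-rolled MPTConfig in favor of the MptConfig class that ships with transformers, so attention settings move from dict lookups (config.attn_config["clip_qkv"]) to attribute access on a nested config object. A minimal sketch of the new access pattern, assuming a transformers release recent enough to include MptConfig; the fields shown are the ones this file reads:

from transformers import MptConfig

# Default MPT configuration as defined by transformers.
config = MptConfig()

# attn_config is a nested config object, so its fields are read as
# attributes rather than as dict keys of the removed vLLM MPTConfig.
print(config.d_model, config.n_heads, config.n_layers)
print(config.attn_config.clip_qkv)        # QKV clipping threshold (may be None)
print(config.attn_config.qk_ln)           # whether Q/K layernorm is applied
print(config.attn_config.alibi)           # ALiBi positional bias enabled
print(config.attn_config.alibi_bias_max)  # maximum ALiBi bias exponent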


@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from transformers import MptConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
@@ -19,7 +20,6 @@ from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                        ColumnParallelLinear,
                                                        RowParallelLinear)
 from vllm.sequence import SamplerOutput
-from vllm.transformers_utils.configs.mpt import MPTConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -37,17 +37,17 @@ def _get_alibi_slopes(
     return slopes
 
 
-class MPTAttention(nn.Module):
+class MptAttention(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
-        assert not config.attn_config["prefix_lm"]
-        assert config.attn_config["alibi"]
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
+        assert not config.attn_config.prefix_lm
+        assert config.attn_config.alibi
 
         self.qkv_proj = ColumnParallelLinear(
             self.d_model,
@@ -105,9 +105,9 @@ class MPTAttention(nn.Module):
         return output
 
 
-class MPTMLP(nn.Module):
+class MptMLP(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         hidden_size = config.d_model
         expansion_ratio = config.expansion_ratio
@@ -133,15 +133,15 @@ class MPTMLP(nn.Module):
         return x
 
 
-class MPTBlock(nn.Module):
+class MptBlock(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MPTAttention(config)
+        self.attn = MptAttention(config)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MPTMLP(config)
+        self.ffn = MptMLP(config)
 
     def forward(
         self,
@@ -166,9 +166,9 @@ class MPTBlock(nn.Module):
         return hidden_states
 
 
-class MPTModel(nn.Module):
+class MptModel(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         assert config.embedding_fraction == 1.0
         assert config.norm_type == "low_precision_layernorm"
@@ -178,7 +178,7 @@ class MPTModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [MPTBlock(config) for _ in range(config.n_layers)])
+            [MptBlock(config) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -213,14 +213,14 @@ class MPTModel(nn.Module):
         return hidden_states
 
 
-class MPTForCausalLM(nn.Module):
+class MptForCausalLM(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         self.config = config
         assert config.tie_word_embeddings
-        self.transformer = MPTModel(config)
+        self.transformer = MptModel(config)
         # TODO(zhuohan): create a new weight after implementing pipeline
         # parallelism
         self.lm_head_weight = self.transformer.wte.weight