Remove MPTConfig (#1529)
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from transformers import MptConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
@@ -19,7 +20,6 @@ from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                        ColumnParallelLinear,
                                                        RowParallelLinear)
 from vllm.sequence import SamplerOutput
-from vllm.transformers_utils.configs.mpt import MPTConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -37,17 +37,17 @@ def _get_alibi_slopes(
     return slopes
 
 
-class MPTAttention(nn.Module):
+class MptAttention(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
-        assert not config.attn_config["prefix_lm"]
-        assert config.attn_config["alibi"]
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
+        assert not config.attn_config.prefix_lm
+        assert config.attn_config.alibi
 
         self.qkv_proj = ColumnParallelLinear(
             self.d_model,
@@ -105,9 +105,9 @@ class MPTAttention(nn.Module):
         return output
 
 
-class MPTMLP(nn.Module):
+class MptMLP(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         hidden_size = config.d_model
         expansion_ratio = config.expansion_ratio
@@ -133,15 +133,15 @@ class MPTMLP(nn.Module):
         return x
 
 
-class MPTBlock(nn.Module):
+class MptBlock(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MPTAttention(config)
+        self.attn = MptAttention(config)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MPTMLP(config)
+        self.ffn = MptMLP(config)
 
     def forward(
         self,
@@ -166,9 +166,9 @@ class MPTBlock(nn.Module):
         return hidden_states
 
 
-class MPTModel(nn.Module):
+class MptModel(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         assert config.embedding_fraction == 1.0
         assert config.norm_type == "low_precision_layernorm"
@@ -178,7 +178,7 @@ class MPTModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [MPTBlock(config) for _ in range(config.n_layers)])
+            [MptBlock(config) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -213,14 +213,14 @@ class MPTModel(nn.Module):
         return hidden_states
 
 
-class MPTForCausalLM(nn.Module):
+class MptForCausalLM(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         self.config = config
         assert config.tie_word_embeddings
 
-        self.transformer = MPTModel(config)
+        self.transformer = MptModel(config)
         # TODO(zhuohan): create a new weight after implementing pipeline
         # parallelism
         self.lm_head_weight = self.transformer.wte.weight
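Note: the practical effect of this change is that vLLM no longer vendors its own MPTConfig (whose attn_config was a plain dict) and instead uses the MptConfig class shipped with Hugging Face transformers, whose attn_config is a nested config object read via attributes. A minimal sketch of the new access pattern, assuming a transformers release that includes MptConfig (roughly v4.32 and later) and its default MPT-7B-style settings:

from transformers import MptConfig

config = MptConfig()

# Old vendored MPTConfig: attn_config was a dict, e.g. config.attn_config["clip_qkv"]
# transformers MptConfig: attn_config is a nested config object, so the same
# fields (taken from the diff above) are read as attributes:
clip_qkv = config.attn_config.clip_qkv
qk_ln = config.attn_config.qk_ln
alibi_bias_max = config.attn_config.alibi_bias_max

# The vLLM model only supports ALiBi and rejects prefix-LM attention,
# mirroring the asserts in MptAttention.__init__ above.
assert config.attn_config.alibi
assert not config.attn_config.prefix_lm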