Revert MptConfig to MPTConfig (#1668)

Author: Megha Agarwal
Date: 2023-11-16 01:19:39 -08:00
Committed by: GitHub
Parent: 7076fa1c9f
Commit: b514d3c496

6 changed files with 260 additions and 26 deletions
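The substance of the revert, visible in the model file below, is that vLLM's own MPTConfig (from vllm.transformers_utils.configs.mpt) stores attn_config as a plain dict, whereas transformers' MptConfig exposes it as an attribute-style sub-config; the Mpt*/MPT* class renames follow from that switch. A minimal sketch of the two access patterns, assuming both default-constructed configs carry a clip_qkv field (illustrative only, not part of this commit):

# Hedged sketch: contrast the two config flavors this commit toggles between.
from transformers import MptConfig
from vllm.transformers_utils.configs.mpt import MPTConfig

hf_config = MptConfig()
# transformers: attn_config is a nested config object -> attribute access
print(hf_config.attn_config.clip_qkv)

vllm_config = MPTConfig()
# vLLM: attn_config is a plain dict -> key access, as in the diff below
print(vllm_config.attn_config["clip_qkv"])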


@@ -5,7 +5,6 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from transformers import MptConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
@@ -22,6 +21,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.mpt import MPTConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -39,21 +39,21 @@ def _get_alibi_slopes(
     return slopes
 
 
-class MptAttention(nn.Module):
+class MPTAttention(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
-        self.clip_qkv = config.attn_config.clip_qkv
-        self.qk_ln = config.attn_config.qk_ln
-        self.alibi_bias_max = config.attn_config.alibi_bias_max
-        assert not config.attn_config.prefix_lm
-        assert config.attn_config.alibi
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        assert not config.attn_config["prefix_lm"]
+        assert config.attn_config["alibi"]
 
         # pylint: disable=invalid-name
         self.Wqkv = QKVParallelLinear(
@@ -113,11 +113,11 @@ class MptAttention(nn.Module):
         return output
 
 
-class MptMLP(nn.Module):
+class MPTMLP(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -145,19 +145,19 @@ class MptMLP(nn.Module):
         return x
 
 
-class MptBlock(nn.Module):
+class MPTBlock(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MptAttention(config, linear_method)
+        self.attn = MPTAttention(config, linear_method)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MptMLP(config, linear_method)
+        self.ffn = MPTMLP(config, linear_method)
 
     def forward(
         self,
@@ -182,11 +182,11 @@ class MptBlock(nn.Module):
         return hidden_states
 
 
-class MptModel(nn.Module):
+class MPTModel(nn.Module):
 
     def __init__(
        self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -198,7 +198,7 @@ class MptModel(nn.Module):
config.d_model,
)
self.blocks = nn.ModuleList(
[MptBlock(config, linear_method) for _ in range(config.n_layers)])
[MPTBlock(config, linear_method) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:
for module in self.modules():
@@ -233,11 +233,11 @@ class MptModel(nn.Module):
         return hidden_states
 
 
-class MptForCausalLM(nn.Module):
+class MPTForCausalLM(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -245,7 +245,7 @@ class MptForCausalLM(nn.Module):
assert config.tie_word_embeddings
self.linear_method = linear_method
self.transformer = MptModel(config, linear_method)
self.transformer = MPTModel(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)