Revert MptConfig to MPTConfig (#1668)
@@ -5,7 +5,6 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from transformers import MptConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
@@ -22,6 +21,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.mpt import MPTConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -39,21 +39,21 @@ def _get_alibi_slopes(
     return slopes
 
 
-class MptAttention(nn.Module):
+class MPTAttention(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
-        self.clip_qkv = config.attn_config.clip_qkv
-        self.qk_ln = config.attn_config.qk_ln
-        self.alibi_bias_max = config.attn_config.alibi_bias_max
-        assert not config.attn_config.prefix_lm
-        assert config.attn_config.alibi
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        assert not config.attn_config["prefix_lm"]
+        assert config.attn_config["alibi"]
 
         # pylint: disable=invalid-name
         self.Wqkv = QKVParallelLinear(
@@ -113,11 +113,11 @@ class MptAttention(nn.Module):
         return output
 
 
-class MptMLP(nn.Module):
+class MPTMLP(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -145,19 +145,19 @@ class MptMLP(nn.Module):
         return x
 
 
-class MptBlock(nn.Module):
+class MPTBlock(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MptAttention(config, linear_method)
+        self.attn = MPTAttention(config, linear_method)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MptMLP(config, linear_method)
+        self.ffn = MPTMLP(config, linear_method)
 
     def forward(
         self,
@@ -182,11 +182,11 @@ class MptBlock(nn.Module):
         return hidden_states
 
 
-class MptModel(nn.Module):
+class MPTModel(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -198,7 +198,7 @@ class MptModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [MptBlock(config, linear_method) for _ in range(config.n_layers)])
+            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -233,11 +233,11 @@ class MptModel(nn.Module):
         return hidden_states
 
 
-class MptForCausalLM(nn.Module):
+class MPTForCausalLM(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -245,7 +245,7 @@ class MptForCausalLM(nn.Module):
         assert config.tie_word_embeddings
         self.linear_method = linear_method
 
-        self.transformer = MptModel(config, linear_method)
+        self.transformer = MPTModel(config, linear_method)
         self.lm_head_weight = self.transformer.wte.weight
         self.sampler = Sampler(config.vocab_size)
 
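Note on the attn_config change above: vLLM's vendored MPTConfig (vllm.transformers_utils.configs.mpt) keeps attn_config as a plain dict, which is why the revert goes back to string indexing, whereas the transformers MptConfig that the reverted code imported exposes the same settings as attributes on a sub-config object. The following sketch is illustration only; get_attn_setting is a hypothetical helper, not part of vLLM, showing how code could read a setting from either config flavor.

from typing import Any


def get_attn_setting(config: Any, key: str, default: Any = None) -> Any:
    attn_config = config.attn_config
    if isinstance(attn_config, dict):
        # vLLM's MPTConfig style: config.attn_config["clip_qkv"]
        return attn_config.get(key, default)
    # transformers MptConfig style: config.attn_config.clip_qkv
    return getattr(attn_config, key, default)

For example, get_attn_setting(config, "clip_qkv") returns the same value for either flavor; the hunks above simply commit to the dict-style access that MPTConfig expects.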