[CI/Build] Fix registry tests (#21934)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
vllm/model_executor/models/mpt.py
@@ -8,7 +8,7 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
-from transformers import PretrainedConfig
+from transformers import MptConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
@@ -50,7 +50,7 @@ class MPTAttention(nn.Module):
 
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: MptConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -59,15 +59,15 @@ class MPTAttention(nn.Module):
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
         self.head_dim = self.d_model // self.total_num_heads
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
         if "kv_n_heads" in config.attn_config:
-            self.total_num_kv_heads = config.attn_config['kv_n_heads']
+            self.total_num_kv_heads = config.attn_config.kv_n_heads
         else:
             self.total_num_kv_heads = self.total_num_heads
-        assert not config.attn_config["prefix_lm"]
-        assert config.attn_config["alibi"]
+        assert not config.attn_config.prefix_lm
+        assert config.attn_config.alibi
 
         # pylint: disable=invalid-name
         self.Wqkv = QKVParallelLinear(
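Note (illustrative, not part of the diff): with transformers' MptConfig, attn_config is an MptAttentionConfig object rather than a plain dict, which is why the lookups above switch from subscripts to attribute access. A minimal sketch of that access pattern, assuming a default MptConfig:

    from transformers import MptConfig

    config = MptConfig()
    attn_cfg = config.attn_config               # an MptAttentionConfig instance, not a dict
    print(type(attn_cfg).__name__)              # MptAttentionConfig
    print(attn_cfg.clip_qkv, attn_cfg.qk_ln)    # read as attributes, not subscripts
    print(attn_cfg.alibi, attn_cfg.alibi_bias_max)
    print(hasattr(attn_cfg, "kv_n_heads"))      # may be absent; the model then falls back to n_heads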
@@ -144,7 +144,7 @@ class MPTMLP(nn.Module):
 
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: MptConfig,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -176,7 +176,7 @@ class MPTBlock(nn.Module):
 
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: MptConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
vllm/model_executor/models/telechat2.py
@@ -37,9 +37,20 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
 class TeleChat2Model(LlamaModel):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        hf_config = vllm_config.model_config.hf_config
+
+        vllm_config.model_config.hf_config.attribute_map = {
+            "num_hidden_layers": "n_layer",
+            "num_attention_heads": "n_head",
+            "intermediate_size": "ffn_hidden_size",
+            "rms_norm_eps": "layer_norm_epsilon"
+        }
+        vllm_config.model_config.hf_config.hidden_act = "silu"
+
         # 1. Initialize the LlamaModel with bias
-        vllm_config.model_config.hf_config.bias = True
-        vllm_config.model_config.hf_config.mlp_bias = True
+        hf_config.bias = True
+        hf_config.mlp_bias = True
+
         super().__init__(vllm_config=vllm_config, prefix=prefix)
         # 2. Remove the bias from the qkv_proj and gate_up_proj based on config
         # Telechat2's gate_up_proj and qkv_proj don't have bias
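Background (not part of the commit): the TeleChat2 hunk sets attribute_map on the HF config so that Llama-style names resolve to TeleChat2's native field names, then aliases the config as hf_config before toggling the bias flags. A minimal sketch of how attribute_map redirects reads in transformers, using a bare PretrainedConfig with made-up values:

    from transformers import PretrainedConfig

    # Hypothetical values, only to show how the mapping resolves attribute reads.
    config = PretrainedConfig(n_layer=30, n_head=32)
    config.attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
    }
    print(config.num_hidden_layers)    # -> 30, redirected to n_layer
    print(config.num_attention_heads)  # -> 32, redirected to n_head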