[6/N] pass whole config to inner model (#10205)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -253,13 +253,18 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class BaiChuanModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
position_embedding: str = "ROPE",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
@@ -332,21 +337,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
position_embedding: str = "ROPE",
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = BaiChuanModel(config, position_embedding, cache_config,
|
||||
quant_config)
|
||||
self.model = BaiChuanModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
position_embedding=position_embedding)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
@@ -438,16 +444,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
NOTE: the class name has a lower case 'c'.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
if config.hidden_size == 4096: # baichuan2 7b
|
||||
super().__init__(vllm_config, prefix, "ROPE")
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
position_embedding="ROPE")
|
||||
else: # baichuan 13b, baichuan2 13b
|
||||
super().__init__(vllm_config, prefix, "ALIBI")
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
position_embedding="ALIBI")
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
@@ -455,9 +461,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
NOTE: the class name has an upper case 'C'.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__(vllm_config, prefix, "ROPE")
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
position_embedding="ROPE")
|
||||
|
||||
Reference in New Issue
Block a user