[5/N] pass the whole config to model (#9983)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-08 22:17:28 -08:00
Committed by: GitHub
Parent: 49d2a41a86
Commit: 1a95f10ee7
75 changed files with 583 additions and 654 deletions
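The pattern introduced here: each model's __init__ now takes a single bundled vllm_config: VllmConfig (plus a module prefix) and unpacks the sub-configs itself, instead of receiving config/cache_config/quant_config/lora_config as separate arguments. Below is a minimal sketch of that constructor shape, assuming only the VllmConfig fields that appear in the diff (model_config.hf_config, cache_config, quant_config, lora_config); the class name is illustrative, not part of the commit.

import torch.nn as nn

from vllm.config import VllmConfig


class ExampleForCausalLM(nn.Module):
    # Illustrative skeleton following the new single-config constructor.

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # Unpack the sub-configs from the bundled VllmConfig,
        # mirroring the BaiChuanBaseForCausalLM change in this commit.
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config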


@@ -26,7 +26,7 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -332,14 +332,15 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def __init__(
         self,
-        config: PretrainedConfig,
-        position_embedding: str,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
+        position_embedding: str = "ROPE",
     ):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
@@ -439,17 +440,14 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
+        config = vllm_config.model_config.hf_config
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(config, "ROPE", cache_config, quant_config,
-                             lora_config)
+            super().__init__(vllm_config, prefix, "ROPE")
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(config, "ALIBI", cache_config, quant_config,
-                             lora_config)
+            super().__init__(vllm_config, prefix, "ALIBI")

 class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
@@ -459,10 +457,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
-        super().__init__(config, "ROPE", cache_config, quant_config,
-                         lora_config)
+        super().__init__(vllm_config, prefix, "ROPE")
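
For reference, a hedged sketch of how a model with this signature is constructed after the change; in vLLM the real call site is the model loader, so this standalone helper (build_baichuan) is hypothetical and for illustration only.

from vllm.config import VllmConfig
from vllm.model_executor.models.baichuan import BaichuanForCausalLM


def build_baichuan(vllm_config: VllmConfig) -> BaichuanForCausalLM:
    # Only the bundled config (and an optional module prefix) is passed;
    # cache/quant/LoRA configs are no longer threaded through the call.
    return BaichuanForCausalLM(vllm_config, prefix="")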