[5/N] pass the whole config to model (#9983)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-08 22:17:28 -08:00
Committed by: GitHub
Parent: 49d2a41a86
Commit: 1a95f10ee7

75 changed files with 583 additions and 654 deletions
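
The pattern repeated across all 75 files is visible in the gemma2.py hunks below: model constructors stop accepting individual sub-configs (config, cache_config, quant_config, lora_config, ...) and instead take the whole VllmConfig, unpacking the pieces they need as locals at the top of __init__. A minimal sketch of the before/after shape, using simplified stand-in config classes rather than vLLM's real ones:

    from dataclasses import dataclass, field
    from typing import Any, Optional

    # Simplified stand-ins for illustration only; the real VllmConfig in
    # vllm/config.py carries many more fields.
    @dataclass
    class ModelConfig:
        hf_config: Any = None  # the HuggingFace config for the model

    @dataclass
    class VllmConfig:
        model_config: ModelConfig = field(default_factory=ModelConfig)
        cache_config: Optional[Any] = None
        quant_config: Optional[Any] = None
        lora_config: Optional[Any] = None

    # Before: each sub-config threaded through every constructor.
    class OldStyleModel:
        def __init__(self, config, cache_config=None, quant_config=None):
            self.config = config

    # After: one object in; locals unpacked at the top, so the rest of
    # the constructor body stays largely unchanged.
    class NewStyleModel:
        def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
            config = vllm_config.model_config.hf_config
            cache_config = vllm_config.cache_config
            quant_config = vllm_config.quant_config
            self.config = config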

vllm/model_executor/models/gemma2.py

@@ -21,7 +21,7 @@ from transformers import Gemma2Config
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import GeluAndMul
@@ -245,12 +245,13 @@ class Gemma2Model(nn.Module):
     def __init__(
         self,
-        config: Gemma2Config,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         self.config = config
         self.embed_tokens = VocabParallelEmbedding(
@@ -400,11 +401,13 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def __init__(
         self,
-        config: Gemma2Config,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ) -> None:
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
         del lora_config  # Unused.
         super().__init__()
         self.config = config
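
The payoff of the uniform (vllm_config, prefix) signature is on the caller side: a loader can instantiate any architecture without knowing which sub-configs it consumes. A sketch of that call site; build_model is a hypothetical name for illustration, not vLLM's actual loader API:

    def build_model(model_cls, vllm_config):
        # Every model now accepts the same two arguments, so one generic
        # call site replaces per-architecture plumbing of cache_config,
        # quant_config, lora_config, etc.
        return model_cls(vllm_config=vllm_config, prefix="")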
@@ -470,14 +473,14 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
     def __init__(
         self,
-        pooler_config: Optional[PoolerConfig] = None,
-        **kwargs,
+        vllm_config: VllmConfig,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.model = Gemma2Model(**kwargs)
+        self.model = Gemma2Model(vllm_config, prefix)
         self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
+            vllm_config.model_config.pooler_config,
             pooling_type=PoolingType.LAST,
             normalize=True,
             softmax=False)
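
Tying the sketches together, building a model now requires only the single config object (all names below are the stand-ins defined in the first sketch, not vLLM's real classes):

    cfg = VllmConfig(model_config=ModelConfig(hf_config=object()))
    model = build_model(NewStyleModel, cfg)
    assert model.config is cfg.model_config.hf_config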