[5/N] pass the whole config to model (#9983)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Commit: 1a95f10ee7 (parent: 49d2a41a86)
Author: youkaichao <youkaichao@gmail.com> (committed by GitHub)
Date: 2024-11-08 22:17:28 -08:00
75 changed files with 583 additions and 654 deletions

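The pattern this series applies to every model is easiest to see outside the diff: instead of threading several optional config objects through each model constructor, the constructor now receives the whole config and unpacks the fields it needs locally. A minimal sketch of the before/after shape, using stand-in dataclasses rather than vLLM's real config classes:

from dataclasses import dataclass, field
from typing import Any, Optional

# Stand-ins for illustration only; vLLM's real VllmConfig carries many
# more fields (device_config, parallel_config, scheduler_config, ...).
@dataclass
class ModelConfig:
    hf_config: Any = None
    pooler_config: Optional[Any] = None

@dataclass
class CacheConfig:
    sliding_window: Optional[int] = None

@dataclass
class VllmConfig:
    model_config: ModelConfig = field(default_factory=ModelConfig)
    cache_config: CacheConfig = field(default_factory=CacheConfig)
    quant_config: Optional[Any] = None
    lora_config: Optional[Any] = None

# Before: each model spelled out every config it might need, so adding
# a new config field meant touching every model signature.
class RewardModelBefore:
    def __init__(self, config: Any,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[Any] = None,
                 lora_config: Optional[Any] = None,
                 pooler_config: Optional[Any] = None) -> None:
        self.config = config
        self.lora_config = lora_config

# After: one uniform signature; the constructor unpacks what it uses.
class RewardModelAfter:
    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        pooler_config = vllm_config.model_config.pooler_config
        self.config = config
        self.lora_config = lora_config

model = RewardModelAfter(VllmConfig(), prefix="model")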

@@ -7,14 +7,12 @@ from typing import Iterable, List, Optional, Tuple, Union
 import torch
 from torch import nn
 from transformers import Qwen2Config
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
@@ -59,12 +57,15 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
     def __init__(
         self,
-        config: Qwen2Config,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
-        pooler_config: Optional[PoolerConfig] = None,
+        vllm_config: VllmConfig,
+        prefix: str = "",
     ) -> None:
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+        pooler_config = vllm_config.model_config.pooler_config
         # TODO (@robertgshaw2): see if this can be moved out
         if (cache_config.sliding_window is not None
                 and hasattr(config, "max_window_layers")):
@@ -77,8 +78,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
                     config.num_hidden_layers,
                 ))
-
-        super().__init__()
         self.config = config
         self.lora_config = lora_config
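A side effect worth noting: once every architecture shares the (vllm_config, prefix) constructor contract, the model loader can instantiate any registered class with one uniform call. A hedged sketch of that call site; the registry and initialize_model helper below are illustrative, not vLLM's actual loader code:

from typing import Any, Callable, Dict

# Illustrative registry; vLLM keeps its own architecture-to-class map.
MODEL_REGISTRY: Dict[str, Callable[..., Any]] = {}

def register(name: str):
    def deco(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return deco

@register("DummyForRewardModel")
class DummyForRewardModel:
    # Same constructor contract as the models touched by this commit.
    def __init__(self, vllm_config: Any, prefix: str = "") -> None:
        self.vllm_config = vllm_config
        self.prefix = prefix

def initialize_model(architecture: str, vllm_config: Any) -> Any:
    # One call shape fits every model: no per-architecture plumbing to
    # decide which of cache/quant/lora/pooler configs to forward.
    model_cls = MODEL_REGISTRY[architecture]
    return model_cls(vllm_config=vllm_config, prefix="model")

model = initialize_model("DummyForRewardModel", vllm_config=object())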