[5/N] pass the whole config to model (#9983)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-08 22:17:28 -08:00
committed by GitHub
parent 49d2a41a86
commit 1a95f10ee7
75 changed files with 583 additions and 654 deletions

View File

@@ -6,7 +6,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
@@ -415,14 +415,16 @@ class ArcticModel(nn.Module):
class ArcticForCausalLM(nn.Module, SupportsPP):
def __init__(self,
config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
**kwargs) -> None:
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.model = ArcticModel(config, cache_config, quant_config)
self.model = ArcticModel(config,
cache_config,
quant_config,
prefix=prefix)
self.vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.vocab_size,