[Core] Refactor model loading code (#4097)

This commit is contained in:
Antoni Baum
2024-04-16 11:34:39 -07:00
committed by GitHub
parent 05434764cd
commit 69e1d2fb69
67 changed files with 1054 additions and 963 deletions

View File

@@ -4,9 +4,9 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
from transformers import PreTrainedTokenizer
import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, TensorizerConfig,
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
@@ -72,11 +72,11 @@ class LLMEngine:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
tensorizer_config: Optional[TensorizerConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -92,8 +92,8 @@ class LLMEngine:
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, "
f"max_seq_len={model_config.max_model_len}, "
f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, "
f"download_dir={load_config.download_dir!r}, "
f"load_format={load_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"disable_custom_all_reduce="
f"{parallel_config.disable_custom_all_reduce}, "
@@ -114,8 +114,8 @@ class LLMEngine:
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.tensorizer_config = tensorizer_config
self.log_stats = log_stats
self._init_tokenizer()
@@ -131,7 +131,7 @@ class LLMEngine:
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
tensorizer_config=tensorizer_config,
load_config=load_config,
)
self._initialize_kv_caches()
@@ -271,9 +271,6 @@ class LLMEngine:
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.tensorizer_config:
self.tensorizer_config.verify_with_parallel_config(
self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(