[Core] Refactor model loading code (#4097)
This commit is contained in:
@@ -4,9 +4,9 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
import vllm
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig, TensorizerConfig,
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||
LoRAConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
@@ -72,11 +72,11 @@ class LLMEngine:
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
decoding_config: Optional[DecodingConfig],
|
||||
tensorizer_config: Optional[TensorizerConfig],
|
||||
executor_class: Type[ExecutorBase],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
@@ -92,8 +92,8 @@ class LLMEngine:
|
||||
f"trust_remote_code={model_config.trust_remote_code}, "
|
||||
f"dtype={model_config.dtype}, "
|
||||
f"max_seq_len={model_config.max_model_len}, "
|
||||
f"download_dir={model_config.download_dir!r}, "
|
||||
f"load_format={model_config.load_format}, "
|
||||
f"download_dir={load_config.download_dir!r}, "
|
||||
f"load_format={load_config.load_format}, "
|
||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"disable_custom_all_reduce="
|
||||
f"{parallel_config.disable_custom_all_reduce}, "
|
||||
@@ -114,8 +114,8 @@ class LLMEngine:
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.speculative_config = speculative_config
|
||||
self.load_config = load_config
|
||||
self.decoding_config = decoding_config or DecodingConfig()
|
||||
self.tensorizer_config = tensorizer_config
|
||||
self.log_stats = log_stats
|
||||
|
||||
self._init_tokenizer()
|
||||
@@ -131,7 +131,7 @@ class LLMEngine:
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
speculative_config=speculative_config,
|
||||
tensorizer_config=tensorizer_config,
|
||||
load_config=load_config,
|
||||
)
|
||||
|
||||
self._initialize_kv_caches()
|
||||
@@ -271,9 +271,6 @@ class LLMEngine:
|
||||
def _verify_args(self) -> None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
if self.tensorizer_config:
|
||||
self.tensorizer_config.verify_with_parallel_config(
|
||||
self.parallel_config)
|
||||
if self.lora_config:
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
self.lora_config.verify_with_scheduler_config(
|
||||
|
||||
Reference in New Issue
Block a user