diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index ed201630d..7064998af 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -14,7 +14,6 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( @@ -287,6 +286,7 @@ class DefaultModelLoader(BaseModelLoader): ): self.load_config.safetensors_load_strategy = "torchao" + weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights(self.get_all_weights(model_config, model)) self.counter_after_loading_weights = time.perf_counter() @@ -295,20 +295,9 @@ class DefaultModelLoader(BaseModelLoader): self.counter_after_loading_weights - self.counter_before_loading_weights, scope="local", ) - self.track_weights_loading(model, loaded_weights) - - def track_weights_loading( - self, model: nn.Module, loaded_weights: set[str] | None - ) -> None: - weights_to_load = {name for name, _ in model.named_parameters()} - if loaded_weights is not None: - for name, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - # ignore kv_cache scale, which can be missing in checkpoints - if isinstance(quant_method, BaseKVCacheMethod): - for param_name, _ in module.named_parameters(): - full_name = f"{name}.{param_name}" if name else param_name - loaded_weights.add(full_name) + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. + if model_config.quantization is None and loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: raise ValueError(