diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 7064998af..ed201630d 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( @@ -286,7 +287,6 @@ class DefaultModelLoader(BaseModelLoader): ): self.load_config.safetensors_load_strategy = "torchao" - weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights(self.get_all_weights(model_config, model)) self.counter_after_loading_weights = time.perf_counter() @@ -295,9 +295,20 @@ class DefaultModelLoader(BaseModelLoader): self.counter_after_loading_weights - self.counter_before_loading_weights, scope="local", ) - # We only enable strict check for non-quantized models - # that have loaded weights tracking currently. - if model_config.quantization is None and loaded_weights is not None: + self.track_weights_loading(model, loaded_weights) + + def track_weights_loading( + self, model: nn.Module, loaded_weights: set[str] | None + ) -> None: + weights_to_load = {name for name, _ in model.named_parameters()} + if loaded_weights is not None: + for name, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + # ignore kv_cache scale, which can be missing in checkpoints + if isinstance(quant_method, BaseKVCacheMethod): + for param_name, _ in module.named_parameters(): + full_name = f"{name}.{param_name}" if name else param_name + loaded_weights.add(full_name) weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: raise ValueError(