diff --git a/vllm/config.py b/vllm/config.py
index 79c9609eb..07dafe51b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -122,15 +122,10 @@ class ModelConfig:
 
         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures:
-            if load_format == "pt":
-                raise ValueError(
-                    "Currently, the 'pt' format is not supported for Mixtral. "
-                    "Please use the 'safetensors' format instead. ")
-            elif load_format == "auto":
-                # Do not fall back to pt weights.
-                load_format = "safetensors"
-
+        if "MixtralForCausalLM" in architectures and load_format == "pt":
+            raise ValueError(
+                "Currently, the 'pt' format is not supported for Mixtral. "
+                "Please use the 'safetensors' format instead. ")
         self.load_format = load_format
 
     def _verify_tokenizer_mode(self) -> None:
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index a3f1582a3..3c6593d83 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -412,7 +412,11 @@ class MixtralForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+                model_name_or_path,
+                cache_dir,
+                load_format,
+                revision,
+                fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index 36ad0f389..bff4fb2f7 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -125,15 +125,29 @@ def get_quant_config(
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_safetensors: bool = False,
+    load_format: str = "auto",
     fall_back_to_pt: bool = True,
     revision: Optional[str] = None,
 ) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
+    use_safetensors = False
     # Some quantized models use .pt files for storing the weights.
-    allow_patterns = ["*.safetensors"
-                      ] if use_safetensors else ["*.bin", "*.pt"]
+    if load_format == "auto":
+        allow_patterns = ["*.safetensors", "*.bin"]
+    elif load_format == "safetensors":
+        use_safetensors = True
+        allow_patterns = ["*.safetensors"]
+    elif load_format == "pt":
+        allow_patterns = ["*.pt"]
+    elif load_format == "npcache":
+        allow_patterns = ["*.bin"]
+    else:
+        raise ValueError(f"Unknown load_format: {load_format}")
+
+    if fall_back_to_pt:
+        allow_patterns += ["*.pt"]
+
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
@@ -148,6 +162,10 @@ def prepare_hf_model_weights(
     hf_weights_files: List[str] = []
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+        if len(hf_weights_files) > 0:
+            if pattern == "*.safetensors":
+                use_safetensors = True
+            break
     if not use_safetensors:
         # Exclude files that are not needed for inference.
         # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
@@ -163,13 +181,6 @@ def prepare_hf_model_weights(
             if not any(f.endswith(x) for x in blacklist)
         ]
 
-    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
-        return prepare_hf_model_weights(model_name_or_path,
-                                        cache_dir=cache_dir,
-                                        use_safetensors=False,
-                                        fall_back_to_pt=False,
-                                        revision=revision)
-
     if len(hf_weights_files) == 0:
         raise RuntimeError(
             f"Cannot find any model weights with `{model_name_or_path}`")
@@ -182,30 +193,16 @@ def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
     revision: Optional[str] = None,
+    fall_back_to_pt: Optional[bool] = True,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-    use_safetensors = False
-    use_np_cache = False
-    fall_back_to_pt = False
-    if load_format == "auto":
-        use_safetensors = True
-        fall_back_to_pt = True
-    elif load_format == "safetensors":
-        use_safetensors = True
-    elif load_format == "pt":
-        pass
-    elif load_format == "npcache":
-        use_np_cache = True
-    else:
-        raise ValueError(f"Unknown load_format: {load_format}")
-
     hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
         model_name_or_path,
         cache_dir=cache_dir,
-        use_safetensors=use_safetensors,
+        load_format=load_format,
         fall_back_to_pt=fall_back_to_pt,
         revision=revision)
 
-    if use_np_cache:
+    if load_format == "npcache":
         # Currently np_cache only support *.bin checkpoints
         assert use_safetensors is False
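
Usage sketch (reviewer illustration, not part of the patch): after this change the
load_format-to-glob-pattern mapping lives entirely in prepare_hf_model_weights, and
the pt fallback becomes an explicit, per-caller opt-out. A minimal example of the
new call path, assuming a checkpoint repo that ships safetensors weights (the model
id below is only illustrative):

    from vllm.model_executor.weight_utils import hf_model_weights_iterator

    # Mixtral's load_weights passes fall_back_to_pt=False, so with
    # load_format="auto" only *.safetensors and *.bin files are globbed;
    # a repo containing nothing but *.pt files now fails fast with
    # "Cannot find any model weights" instead of loading pt weights.
    for name, tensor in hf_model_weights_iterator(
            "mistralai/Mixtral-8x7B-v0.1",
            cache_dir=None,
            load_format="auto",
            revision=None,
            fall_back_to_pt=False):
        print(name, tuple(tensor.shape))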