Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -16,12 +16,18 @@ from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference, maybe_download_from_modelscope,
download_safetensors_index_file_from_hf,
download_weights_from_hf,
fastsafetensors_weights_iterator,
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
maybe_download_from_modelscope,
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
multi_thread_safetensors_weights_iterator,
np_cache_weights_iterator,
pt_weights_iterator,
safetensors_weights_iterator,
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -63,9 +69,11 @@ class DefaultModelLoader(BaseModelLoader):
unexpected_keys = set(extra_config.keys()) - allowed_keys
if unexpected_keys:
raise ValueError(f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{unexpected_keys}")
raise ValueError(
f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{unexpected_keys}"
)
def _prepare_weights(
self,
@@ -77,8 +85,10 @@ class DefaultModelLoader(BaseModelLoader):
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path = (maybe_download_from_modelscope(
model_name_or_path, revision) or model_name_or_path)
model_name_or_path = (
maybe_download_from_modelscope(model_name_or_path, revision)
or model_name_or_path
)
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
@@ -87,8 +97,7 @@ class DefaultModelLoader(BaseModelLoader):
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
allow_patterns = ["*.safetensors", "*.bin"]
elif (load_format == "safetensors"
or load_format == "fastsafetensors"):
elif load_format == "safetensors" or load_format == "fastsafetensors":
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == "mistral":
@@ -141,25 +150,29 @@ class DefaultModelLoader(BaseModelLoader):
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file)
hf_weights_files, hf_folder, index_file
)
else:
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
f"Cannot find any model weights with `{model_name_or_path}`"
)
return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator(
self, source: "Source"
self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
extra_config = self.load_config.model_loader_extra_config
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt,
source.allow_patterns_overrides)
source.model_or_path,
source.revision,
source.fall_back_to_pt,
source.allow_patterns_overrides,
)
if self.load_config.load_format == "npcache":
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
@@ -178,13 +191,13 @@ class DefaultModelLoader(BaseModelLoader):
)
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = (
multi_thread_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS),
))
weights_iterator = multi_thread_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS
),
)
else:
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
@@ -197,8 +210,9 @@ class DefaultModelLoader(BaseModelLoader):
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
max_workers=extra_config.get("num_threads",
self.DEFAULT_NUM_THREADS),
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS
),
)
else:
weights_iterator = pt_weights_iterator(
@@ -226,8 +240,7 @@ class DefaultModelLoader(BaseModelLoader):
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix.
return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator)
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
def get_all_weights(
self,
@@ -238,10 +251,8 @@ class DefaultModelLoader(BaseModelLoader):
model_config.model,
model_config.revision,
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
True),
allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
None),
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
)
yield from self._get_weights_iterator(primary_weights)
@@ -253,13 +264,14 @@ class DefaultModelLoader(BaseModelLoader):
yield from self._get_weights_iterator(source)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model,
model_config.revision,
fall_back_to_pt=True,
allow_patterns_overrides=None)
self._prepare_weights(
model_config.model,
model_config.revision,
fall_back_to_pt=True,
allow_patterns_overrides=None,
)
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
weights_to_load = {name for name, _ in model.named_parameters()}
# if we don't have `model.weight_metadata_and_attr_saved` defined and
@@ -267,38 +279,43 @@ class DefaultModelLoader(BaseModelLoader):
# or the first run of online quantization
# see online_quantization.py for detailed notes
offline_quantization_or_first_run_of_online_quantization = not getattr(
model, "weight_metadata_and_attr_saved", False)
model, "weight_metadata_and_attr_saved", False
)
if model_config.quantization is None:
# model is not quantized
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model))
self.get_all_weights(model_config, model)
)
elif offline_quantization_or_first_run_of_online_quantization:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
# see online_quantization.py for detailed notes
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model))
self.get_all_weights(model_config, model)
)
else:
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
load_weights_and_online_quantize)
load_weights_and_online_quantize,
)
# subsequent runs of weight loading with online
# quantization
loaded_weights = load_weights_and_online_quantize(
self, model, model_config)
loaded_weights = load_weights_and_online_quantize(self, model, model_config)
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights -
self.counter_before_loading_weights)
self.counter_after_loading_weights - self.counter_before_loading_weights,
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError("Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)