Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -19,7 +19,7 @@ def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
"""Create weights from safetensors checkpoint metadata"""
|
||||
metadata = try_get_safetensors_metadata(repo)
|
||||
weight_names = list(metadata.weight_map.keys())
|
||||
with torch.device('meta'):
|
||||
with torch.device("meta"):
|
||||
return ((name, torch.empty(0)) for name in weight_names)
|
||||
|
||||
|
||||
@@ -61,7 +61,8 @@ def test_hf_model_weights_mapper(model_arch: str):
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
skip_tokenizer_init=model_info.skip_tokenizer_init,
|
||||
enforce_eager=model_info.enforce_eager,
|
||||
dtype=model_info.dtype)
|
||||
dtype=model_info.dtype,
|
||||
)
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
|
||||
original_weights = create_repo_dummy_weights(model_id)
|
||||
@@ -83,6 +84,7 @@ def test_hf_model_weights_mapper(model_arch: str):
|
||||
|
||||
weights_missing = ref_weight_names - weight_names
|
||||
weights_unmapped = weight_names - ref_weight_names
|
||||
assert (not weights_missing and not weights_unmapped), (
|
||||
assert not weights_missing and not weights_unmapped, (
|
||||
f"Following weights are not mapped correctly: {weights_unmapped}, "
|
||||
f"Missing expected weights: {weights_missing}.")
|
||||
f"Missing expected weights: {weights_missing}."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user