Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -33,16 +33,18 @@ def check_outputs_equal(
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
for prompt_idx, (outputs_0,
outputs_1) in enumerate(zip(outputs_0_lst,
outputs_1_lst)):
for prompt_idx, (outputs_0, outputs_1) in enumerate(
zip(outputs_0_lst, outputs_1_lst)
):
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1
# The text and token outputs should exactly match
fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
fail_msg = (
f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}"
)
assert output_str_0 == output_str_1, fail_msg
assert output_ids_0 == output_ids_1, fail_msg
@@ -54,9 +56,9 @@ def check_outputs_equal(
# * List of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TokensTextLogprobs = tuple[list[int], str, Optional[Union[list[dict[int,
float]],
SampleLogprobs]]]
TokensTextLogprobs = tuple[
list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]]
]
# Allow for tokens to be represented as str's rather than IDs;
# tuple of
@@ -65,9 +67,9 @@ TokensTextLogprobs = tuple[list[int], str, Optional[Union[list[dict[int,
# * Optional list of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TextTextLogprobs = tuple[list[str], str, Optional[Union[list[dict[str, float]],
list[dict[str,
Logprob]]]]]
TextTextLogprobs = tuple[
list[str], str, Optional[Union[list[dict[str, float]], list[dict[str, Logprob]]]]
]
# Representation of generated sequence as a tuple of
# * Token ID list
@@ -77,18 +79,21 @@ TextTextLogprobs = tuple[list[str], str, Optional[Union[list[dict[str, float]],
#
# Allows prompt logprobs to be requested.
TokensTextLogprobsPromptLogprobs = tuple[
list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]],
Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]]]
list[int],
str,
Optional[Union[list[dict[int, float]], SampleLogprobs]],
Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]],
]
def check_logprobs_close(
*,
outputs_0_lst: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
TextTextLogprobs]],
outputs_1_lst: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
TextTextLogprobs]],
outputs_0_lst: Sequence[
Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs]
],
outputs_1_lst: Sequence[
Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs]
],
name_0: str,
name_1: str,
num_outputs_0_skip_tokens: int = 0,
@@ -128,9 +133,9 @@ def check_logprobs_close(
assert len(outputs_0_lst) == len(outputs_1_lst)
# Loop through responses to each prompt.
for prompt_idx, (outputs_0,
outputs_1) in enumerate(zip(outputs_0_lst,
outputs_1_lst)):
for prompt_idx, (outputs_0, outputs_1) in enumerate(
zip(outputs_0_lst, outputs_1_lst)
):
assert len(outputs_0) == len(outputs_1)
if len(outputs_0) == 3:
assert len(outputs_1) == 3
@@ -155,17 +160,18 @@ def check_logprobs_close(
) = outputs_1
# Test prompt logprobs closeness
if (prompt_logprobs_0 is not None
and prompt_logprobs_1 is not None):
if prompt_logprobs_0 is not None and prompt_logprobs_1 is not None:
# Both sequences' prompt logprobs lists are not `None``
# (although individual list elements may be `None`);
# for each token's logprobs:
for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
zip(prompt_logprobs_0, prompt_logprobs_1)):
zip(prompt_logprobs_0, prompt_logprobs_1)
):
fail_msg = (
f"Prompt logprobs test:"
f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}"
)
if logprobs_elem_0 is None:
# If the seq 0 token's logprobs are `None`,
@@ -176,20 +182,24 @@ def check_logprobs_close(
# the seq 1 token's logprobs must not be `None`
assert logprobs_elem_1 is not None, fail_msg
# Logprobs check: top-k token choices must be the same
assert (set(logprobs_elem_0.keys()) == set(
logprobs_elem_1.keys())), fail_msg
assert set(logprobs_elem_0.keys()) == set(
logprobs_elem_1.keys()
), fail_msg
else:
# Both sequence logprobs lists must be `None`
fail_msg = (f"Prompt logprobs test:"
f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
fail_msg = (
f"Prompt logprobs test:"
f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}"
)
assert (prompt_logprobs_0 is None
and prompt_logprobs_1 is None), fail_msg
assert prompt_logprobs_0 is None and prompt_logprobs_1 is None, fail_msg
else:
raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
f"{len(outputs_0)} elements were provided: "
f"{outputs_0}")
raise ValueError(
f"Outputs tuple must have 3 or 4 elements but "
f"{len(outputs_0)} elements were provided: "
f"{outputs_0}"
)
if logprobs_0 is None:
logprobs_0 = [None] * len(output_ids_0)
@@ -206,9 +216,9 @@ def check_logprobs_close(
logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]
# Loop through generated tokens.
for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
for idx, (output_id_0, output_id_1) in enumerate(
zip(output_ids_0, output_ids_1)
):
is_tok_mismatch = output_id_0 != output_id_1
# If generated tokens don't match
@@ -223,7 +233,8 @@ def check_logprobs_close(
f"Test{prompt_idx}:"
f"\nMatched tokens:\t{output_ids_0[:idx]}"
f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")
f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}"
)
assert logprobs_elem_0 is not None, fail_msg
assert logprobs_elem_1 is not None, fail_msg
@@ -244,9 +255,11 @@ def check_logprobs_close(
if output_str_0 != output_str_1 and warn_on_mismatch:
# The token outputs exactly match,
# so the text outputs should exactly match as well
fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
fail_msg = (
f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}"
)
with warnings.catch_warnings():
# This ensures that repeated warnings are shown
@@ -317,18 +330,22 @@ def check_embeddings_close(
assert len(embeddings_0_lst) == len(embeddings_1_lst)
for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
zip(embeddings_0_lst, embeddings_1_lst)):
zip(embeddings_0_lst, embeddings_1_lst)
):
assert len(embeddings_0) == len(embeddings_1), (
f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")
f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}"
)
sim = F.cosine_similarity(torch.tensor(embeddings_0),
torch.tensor(embeddings_1),
dim=0)
sim = F.cosine_similarity(
torch.tensor(embeddings_0), torch.tensor(embeddings_1), dim=0
)
fail_msg = (f"Test{prompt_idx}:"
f"\nCosine similarity: \t{sim:.4f}"
f"\n{name_0}:\t{embeddings_0[:16]!r}"
f"\n{name_1}:\t{embeddings_1[:16]!r}")
fail_msg = (
f"Test{prompt_idx}:"
f"\nCosine similarity: \t{sim:.4f}"
f"\n{name_0}:\t{embeddings_0[:16]!r}"
f"\n{name_1}:\t{embeddings_1[:16]!r}"
)
assert sim >= 1 - tol, fail_msg
@@ -413,20 +430,19 @@ def dummy_hf_overrides(
# Ensure at least 2 expert per group
# Since `grouped_topk` assumes top-2
n_group = getattr(text_config, 'n_group', None)
n_group = getattr(text_config, "n_group", None)
num_experts = n_group * 2 if n_group is not None else 2
# we use three layers for Gemma-3n to check
# both normal layer and kv_shared_layer
if use_original_num_layers:
# Use the original number of layers from the config
num_layers = getattr(text_config, 'num_layers', 1)
num_hidden_layers = getattr(text_config, 'num_hidden_layers', 1)
num_layers = getattr(text_config, "num_layers", 1)
num_hidden_layers = getattr(text_config, "num_hidden_layers", 1)
else:
# Use minimal layers for testing
num_layers = 1
num_hidden_layers = (3 if model_arch
== "Gemma3nForConditionalGeneration" else 1)
num_hidden_layers = 3 if model_arch == "Gemma3nForConditionalGeneration" else 1
update_dict = {
"num_layers": num_layers,
@@ -440,53 +456,63 @@ def dummy_hf_overrides(
# Only set MoE related config when the model has MoE layers.
# Otherwise all models detected as MoE by _get_transformers_backend_cls.
if ModelConfig.get_num_experts(DummyConfig) > 0:
update_dict.update({
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
# Otherwise there will not be any expert layers
"first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts,
})
update_dict.update(
{
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
# Otherwise there will not be any expert layers
"first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts,
}
)
# Update num_hidden_layers for non-Longcat architectures
if model_arch != "LongcatFlashForCausalLM" \
and model_arch != "LongCatFlashMTPModel":
if model_arch != "LongcatFlashForCausalLM" and model_arch != "LongCatFlashMTPModel":
update_dict["num_hidden_layers"] = num_hidden_layers
text_config.update(update_dict)
if hasattr(hf_config, "vision_config"):
hf_config.vision_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})
hf_config.vision_config.update(
{
"num_layers": 1,
"num_hidden_layers": 1,
}
)
# e.g.: ibm-granite/granite-speech-3.3-2b
if hasattr(hf_config, "encoder_config"):
hf_config.encoder_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})
hf_config.encoder_config.update(
{
"num_layers": 1,
"num_hidden_layers": 1,
}
)
# e.g.: Qwen/Qwen2-Audio-7B-Instruct
if hasattr(hf_config, "audio_config"):
hf_config.audio_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
"encoder_layers": 1,
})
hf_config.audio_config.update(
{
"num_layers": 1,
"num_hidden_layers": 1,
"encoder_layers": 1,
}
)
return hf_config
def check_transformers_version(model: str,
min_transformers_version: Optional[str] = None,
max_transformers_version: Optional[str] = None):
def check_transformers_version(
model: str,
min_transformers_version: Optional[str] = None,
max_transformers_version: Optional[str] = None,
):
from .registry import _HfExamplesInfo
return _HfExamplesInfo(model,
min_transformers_version=min_transformers_version,
max_transformers_version=max_transformers_version
).check_transformers_version(on_fail="skip")
return _HfExamplesInfo(
model,
min_transformers_version=min_transformers_version,
max_transformers_version=max_transformers_version,
).check_transformers_version(on_fail="skip")