Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
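The diff below reformats vLLM's model-test utilities from the yapf + isort style (continuation lines aligned with the opening bracket, single-quoted strings) to ruff's black-compatible formatting (hanging indents, dedented closing brackets, trailing commas). As a rough illustration only, not code from this commit, the two styles look like this; `compare_outputs` and its arguments are made-up names:

```python
# Illustrative sketch, not part of the commit: the same call formatted in the
# old yapf style and in the ruff (black-compatible) style this PR adopts.
# `compare_outputs`, `baseline_outputs` and `candidate_outputs` are hypothetical.


def compare_outputs(outputs_0, outputs_1, *, name_0, name_1):
    """Toy stand-in for the test helpers touched by this diff."""
    assert outputs_0 == outputs_1, f"{name_0} != {name_1}"


baseline_outputs = ([1, 2, 3], "abc")
candidate_outputs = ([1, 2, 3], "abc")

# yapf: continuation arguments aligned with the opening parenthesis.
compare_outputs(baseline_outputs,
                candidate_outputs,
                name_0="baseline",
                name_1="candidate")

# ruff format: hanging indent, one argument per line when wrapping,
# trailing comma, closing parenthesis dedented to the statement level.
compare_outputs(
    baseline_outputs,
    candidate_outputs,
    name_0="baseline",
    name_1="candidate",
)
```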
@@ -33,16 +33,18 @@ def check_outputs_equal(
     """
     assert len(outputs_0_lst) == len(outputs_1_lst)

-    for prompt_idx, (outputs_0,
-                     outputs_1) in enumerate(zip(outputs_0_lst,
-                                                 outputs_1_lst)):
+    for prompt_idx, (outputs_0, outputs_1) in enumerate(
+        zip(outputs_0_lst, outputs_1_lst)
+    ):
         output_ids_0, output_str_0 = outputs_0
         output_ids_1, output_str_1 = outputs_1

         # The text and token outputs should exactly match
-        fail_msg = (f"Test{prompt_idx}:"
-                    f"\n{name_0}:\t{output_str_0!r}"
-                    f"\n{name_1}:\t{output_str_1!r}")
+        fail_msg = (
+            f"Test{prompt_idx}:"
+            f"\n{name_0}:\t{output_str_0!r}"
+            f"\n{name_1}:\t{output_str_1!r}"
+        )

         assert output_str_0 == output_str_1, fail_msg
         assert output_ids_0 == output_ids_1, fail_msg
@@ -54,9 +56,9 @@ def check_outputs_equal(
 # * List of top sample logprobs for each sampled token
 #
 # Assumes prompt logprobs were not requested.
-TokensTextLogprobs = tuple[list[int], str, Optional[Union[list[dict[int,
-                                                                    float]],
-                                                           SampleLogprobs]]]
+TokensTextLogprobs = tuple[
+    list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]]
+]

 # Allow for tokens to be represented as str's rather than IDs;
 # tuple of
@@ -65,9 +67,9 @@ TokensTextLogprobs = tuple[list[int], str, Optional[Union[list[dict[int,
 # * Optional list of top sample logprobs for each sampled token
 #
 # Assumes prompt logprobs were not requested.
-TextTextLogprobs = tuple[list[str], str, Optional[Union[list[dict[str, float]],
-                                                         list[dict[str,
-                                                                   Logprob]]]]]
+TextTextLogprobs = tuple[
+    list[str], str, Optional[Union[list[dict[str, float]], list[dict[str, Logprob]]]]
+]

 # Representation of generated sequence as a tuple of
 # * Token ID list
@@ -77,18 +79,21 @@ TextTextLogprobs = tuple[list[str], str, Optional[Union[list[dict[str, float]],
 #
 # Allows prompt logprobs to be requested.
 TokensTextLogprobsPromptLogprobs = tuple[
-    list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]],
-    Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]]]
+    list[int],
+    str,
+    Optional[Union[list[dict[int, float]], SampleLogprobs]],
+    Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]],
+]


 def check_logprobs_close(
     *,
-    outputs_0_lst: Sequence[Union[TokensTextLogprobs,
-                                  TokensTextLogprobsPromptLogprobs,
-                                  TextTextLogprobs]],
-    outputs_1_lst: Sequence[Union[TokensTextLogprobs,
-                                  TokensTextLogprobsPromptLogprobs,
-                                  TextTextLogprobs]],
+    outputs_0_lst: Sequence[
+        Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs]
+    ],
+    outputs_1_lst: Sequence[
+        Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs]
+    ],
     name_0: str,
     name_1: str,
     num_outputs_0_skip_tokens: int = 0,
@@ -128,9 +133,9 @@ def check_logprobs_close(
     assert len(outputs_0_lst) == len(outputs_1_lst)

     # Loop through responses to each prompt.
-    for prompt_idx, (outputs_0,
-                     outputs_1) in enumerate(zip(outputs_0_lst,
-                                                 outputs_1_lst)):
+    for prompt_idx, (outputs_0, outputs_1) in enumerate(
+        zip(outputs_0_lst, outputs_1_lst)
+    ):
         assert len(outputs_0) == len(outputs_1)
         if len(outputs_0) == 3:
             assert len(outputs_1) == 3
@@ -155,17 +160,18 @@ def check_logprobs_close(
             ) = outputs_1

             # Test prompt logprobs closeness
-            if (prompt_logprobs_0 is not None
-                    and prompt_logprobs_1 is not None):
+            if prompt_logprobs_0 is not None and prompt_logprobs_1 is not None:
                 # Both sequences' prompt logprobs lists are not `None``
                 # (although individual list elements may be `None`);
                 # for each token's logprobs:
                 for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
-                        zip(prompt_logprobs_0, prompt_logprobs_1)):
+                    zip(prompt_logprobs_0, prompt_logprobs_1)
+                ):
                     fail_msg = (
                         f"Prompt logprobs test:"
                         f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
-                        f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
+                        f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}"
+                    )

                     if logprobs_elem_0 is None:
                         # If the seq 0 token's logprobs are `None`,
@@ -176,20 +182,24 @@ def check_logprobs_close(
                         # the seq 1 token's logprobs must not be `None`
                         assert logprobs_elem_1 is not None, fail_msg
                         # Logprobs check: top-k token choices must be the same
-                        assert (set(logprobs_elem_0.keys()) == set(
-                            logprobs_elem_1.keys())), fail_msg
+                        assert set(logprobs_elem_0.keys()) == set(
+                            logprobs_elem_1.keys()
+                        ), fail_msg
             else:
                 # Both sequence logprobs lists must be `None`
-                fail_msg = (f"Prompt logprobs test:"
-                            f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
-                            f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
+                fail_msg = (
+                    f"Prompt logprobs test:"
+                    f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
+                    f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}"
+                )

-                assert (prompt_logprobs_0 is None
-                        and prompt_logprobs_1 is None), fail_msg
+                assert prompt_logprobs_0 is None and prompt_logprobs_1 is None, fail_msg
         else:
-            raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
-                             f"{len(outputs_0)} elements were provided: "
-                             f"{outputs_0}")
+            raise ValueError(
+                f"Outputs tuple must have 3 or 4 elements but "
+                f"{len(outputs_0)} elements were provided: "
+                f"{outputs_0}"
+            )

         if logprobs_0 is None:
             logprobs_0 = [None] * len(output_ids_0)
@@ -206,9 +216,9 @@ def check_logprobs_close(
             logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]

         # Loop through generated tokens.
-        for idx, (output_id_0,
-                  output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
-
+        for idx, (output_id_0, output_id_1) in enumerate(
+            zip(output_ids_0, output_ids_1)
+        ):
             is_tok_mismatch = output_id_0 != output_id_1

             # If generated tokens don't match
@@ -223,7 +233,8 @@ def check_logprobs_close(
                     f"Test{prompt_idx}:"
                     f"\nMatched tokens:\t{output_ids_0[:idx]}"
                     f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
-                    f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")
+                    f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}"
+                )

                 assert logprobs_elem_0 is not None, fail_msg
                 assert logprobs_elem_1 is not None, fail_msg
@@ -244,9 +255,11 @@ def check_logprobs_close(
         if output_str_0 != output_str_1 and warn_on_mismatch:
             # The token outputs exactly match,
             # so the text outputs should exactly match as well
-            fail_msg = (f"Test{prompt_idx}:"
-                        f"\n{name_0}:\t{output_str_0!r}"
-                        f"\n{name_1}:\t{output_str_1!r}")
+            fail_msg = (
+                f"Test{prompt_idx}:"
+                f"\n{name_0}:\t{output_str_0!r}"
+                f"\n{name_1}:\t{output_str_1!r}"
+            )

             with warnings.catch_warnings():
                 # This ensures that repeated warnings are shown
@@ -317,18 +330,22 @@ def check_embeddings_close(
     assert len(embeddings_0_lst) == len(embeddings_1_lst)

     for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
-            zip(embeddings_0_lst, embeddings_1_lst)):
+        zip(embeddings_0_lst, embeddings_1_lst)
+    ):
         assert len(embeddings_0) == len(embeddings_1), (
-            f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")
+            f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}"
+        )

-        sim = F.cosine_similarity(torch.tensor(embeddings_0),
-                                  torch.tensor(embeddings_1),
-                                  dim=0)
+        sim = F.cosine_similarity(
+            torch.tensor(embeddings_0), torch.tensor(embeddings_1), dim=0
+        )

-        fail_msg = (f"Test{prompt_idx}:"
-                    f"\nCosine similarity: \t{sim:.4f}"
-                    f"\n{name_0}:\t{embeddings_0[:16]!r}"
-                    f"\n{name_1}:\t{embeddings_1[:16]!r}")
+        fail_msg = (
+            f"Test{prompt_idx}:"
+            f"\nCosine similarity: \t{sim:.4f}"
+            f"\n{name_0}:\t{embeddings_0[:16]!r}"
+            f"\n{name_1}:\t{embeddings_1[:16]!r}"
+        )

         assert sim >= 1 - tol, fail_msg

@@ -413,20 +430,19 @@ def dummy_hf_overrides(

     # Ensure at least 2 expert per group
     # Since `grouped_topk` assumes top-2
-    n_group = getattr(text_config, 'n_group', None)
+    n_group = getattr(text_config, "n_group", None)
     num_experts = n_group * 2 if n_group is not None else 2

     # we use three layers for Gemma-3n to check
     # both normal layer and kv_shared_layer
     if use_original_num_layers:
         # Use the original number of layers from the config
-        num_layers = getattr(text_config, 'num_layers', 1)
-        num_hidden_layers = getattr(text_config, 'num_hidden_layers', 1)
+        num_layers = getattr(text_config, "num_layers", 1)
+        num_hidden_layers = getattr(text_config, "num_hidden_layers", 1)
     else:
         # Use minimal layers for testing
         num_layers = 1
-        num_hidden_layers = (3 if model_arch
-                             == "Gemma3nForConditionalGeneration" else 1)
+        num_hidden_layers = 3 if model_arch == "Gemma3nForConditionalGeneration" else 1

     update_dict = {
         "num_layers": num_layers,
@@ -440,53 +456,63 @@ def dummy_hf_overrides(
     # Only set MoE related config when the model has MoE layers.
     # Otherwise all models detected as MoE by _get_transformers_backend_cls.
     if ModelConfig.get_num_experts(DummyConfig) > 0:
-        update_dict.update({
-            "num_experts": num_experts,
-            "num_experts_per_tok": 2,
-            "num_local_experts": num_experts,
-            # Otherwise there will not be any expert layers
-            "first_k_dense_replace": 0,
-            # To avoid OOM on DeepSeek-V3
-            "n_routed_experts": num_experts,
-        })
+        update_dict.update(
+            {
+                "num_experts": num_experts,
+                "num_experts_per_tok": 2,
+                "num_local_experts": num_experts,
+                # Otherwise there will not be any expert layers
+                "first_k_dense_replace": 0,
+                # To avoid OOM on DeepSeek-V3
+                "n_routed_experts": num_experts,
+            }
+        )

     # Update num_hidden_layers for non-Longcat architectures
-    if model_arch != "LongcatFlashForCausalLM" \
-            and model_arch != "LongCatFlashMTPModel":
+    if model_arch != "LongcatFlashForCausalLM" and model_arch != "LongCatFlashMTPModel":
         update_dict["num_hidden_layers"] = num_hidden_layers

     text_config.update(update_dict)

     if hasattr(hf_config, "vision_config"):
-        hf_config.vision_config.update({
-            "num_layers": 1,
-            "num_hidden_layers": 1,
-        })
+        hf_config.vision_config.update(
+            {
+                "num_layers": 1,
+                "num_hidden_layers": 1,
+            }
+        )

     # e.g.: ibm-granite/granite-speech-3.3-2b
     if hasattr(hf_config, "encoder_config"):
-        hf_config.encoder_config.update({
-            "num_layers": 1,
-            "num_hidden_layers": 1,
-        })
+        hf_config.encoder_config.update(
+            {
+                "num_layers": 1,
+                "num_hidden_layers": 1,
+            }
+        )

     # e.g.: Qwen/Qwen2-Audio-7B-Instruct
     if hasattr(hf_config, "audio_config"):
-        hf_config.audio_config.update({
-            "num_layers": 1,
-            "num_hidden_layers": 1,
-            "encoder_layers": 1,
-        })
+        hf_config.audio_config.update(
+            {
+                "num_layers": 1,
+                "num_hidden_layers": 1,
+                "encoder_layers": 1,
+            }
+        )

     return hf_config


-def check_transformers_version(model: str,
-                               min_transformers_version: Optional[str] = None,
-                               max_transformers_version: Optional[str] = None):
+def check_transformers_version(
+    model: str,
+    min_transformers_version: Optional[str] = None,
+    max_transformers_version: Optional[str] = None,
+):
     from .registry import _HfExamplesInfo

-    return _HfExamplesInfo(model,
-                           min_transformers_version=min_transformers_version,
-                           max_transformers_version=max_transformers_version
-                           ).check_transformers_version(on_fail="skip")
+    return _HfExamplesInfo(
+        model,
+        min_transformers_version=min_transformers_version,
+        max_transformers_version=max_transformers_version,
+    ).check_transformers_version(on_fail="skip")
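For context, a repo-wide reformat like the diff above is typically produced mechanically. The sketch below shows the kind of commands involved; the exact invocation and configuration are assumptions, not taken from this PR:

```python
# Hypothetical sketch of how such a repo-wide reformat is usually produced;
# the specific invocation and any config are assumptions, not from this PR.
import subprocess

# Apply ruff's black-compatible formatter (replaces yapf).
subprocess.run(["ruff", "format", "."], check=True)

# Apply ruff's import-sorting lint rules with autofix (replaces isort).
subprocess.run(["ruff", "check", "--select", "I", "--fix", "."], check=True)
```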