Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -62,7 +62,7 @@ class NemotronConfig(PretrainedConfig):
|
||||
(MQA) otherwise GQA is used. When converting a multi-head
|
||||
checkpoint to a GQA checkpoint, each group key and value
|
||||
head should be constructed by meanpooling all the original
|
||||
heads within that group. For more details checkout
|
||||
heads within that group. For more details checkout
|
||||
[this paper](https://arxiv.org/pdf/2305.13245.pdf). If it
|
||||
is not specified, will default to `num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
|
||||
@@ -147,8 +147,9 @@ class NemotronConfig(PretrainedConfig):
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
head_dim = head_dim or kwargs.get("kv_channels")
|
||||
self.head_dim = head_dim if head_dim is not None else (
|
||||
hidden_size // num_attention_heads)
|
||||
self.head_dim = (
|
||||
head_dim if head_dim is not None else (hidden_size // num_attention_heads)
|
||||
)
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
@@ -162,8 +163,11 @@ class NemotronConfig(PretrainedConfig):
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
# for backward compatibility
|
||||
partial_rotary_factor = kwargs.get("rope_percent") or kwargs.get(
|
||||
"rope_percentage") or partial_rotary_factor
|
||||
partial_rotary_factor = (
|
||||
kwargs.get("rope_percent")
|
||||
or kwargs.get("rope_percentage")
|
||||
or partial_rotary_factor
|
||||
)
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
self._rope_scaling_validation()
|
||||
self.attention_bias = attention_bias
|
||||
@@ -185,21 +189,24 @@ class NemotronConfig(PretrainedConfig):
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(
|
||||
self.rope_scaling) != 2:
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with two fields, "
|
||||
f"`type` and `factor`, got {self.rope_scaling}")
|
||||
f"`type` and `factor`, got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in [
|
||||
"linear", "dynamic"
|
||||
]:
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
"`rope_scaling`'s type field must be one of ['linear', "
|
||||
f"'dynamic'], got {rope_scaling_type}")
|
||||
if rope_scaling_factor is None or not isinstance(
|
||||
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
f"'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if (
|
||||
rope_scaling_factor is None
|
||||
or not isinstance(rope_scaling_factor, float)
|
||||
or rope_scaling_factor <= 1.0
|
||||
):
|
||||
raise ValueError(
|
||||
"`rope_scaling`'s factor field must be a float > 1, got "
|
||||
f"{rope_scaling_factor}")
|
||||
f"{rope_scaling_factor}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user