Add util function for checking nesting of rope parameters (#31146)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -15,7 +15,6 @@ from huggingface_hub import (
|
||||
)
|
||||
from packaging.version import Version
|
||||
from transformers import GenerationConfig, PretrainedConfig
|
||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
||||
from transformers.models.auto.image_processing_auto import get_image_processor_config
|
||||
from transformers.models.auto.modeling_auto import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
@@ -44,6 +43,16 @@ from .repo_utils import (
|
||||
with_retry,
|
||||
)
|
||||
|
||||
try:
|
||||
# Transformers v5
|
||||
from transformers.configuration_utils import ALLOWED_ATTENTION_LAYER_TYPES
|
||||
except ImportError:
|
||||
# Transformers v4
|
||||
from transformers.configuration_utils import (
|
||||
ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES,
|
||||
)
|
||||
|
||||
|
||||
if envs.VLLM_USE_MODELSCOPE:
|
||||
from modelscope import AutoConfig
|
||||
else:
|
||||
@@ -104,6 +113,14 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
|
||||
}
|
||||
|
||||
|
||||
def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool:
    """Report whether ``rope_parameters`` is keyed by attention layer type.

    Returns ``False`` for an empty mapping. Otherwise returns ``True`` only
    when every key of ``rope_parameters`` is a member of
    ``ALLOWED_ATTENTION_LAYER_TYPES``.
    """
    # An empty mapping has no keys to inspect, so it is never treated as nested.
    if rope_parameters:
        return all(key in ALLOWED_ATTENTION_LAYER_TYPES for key in rope_parameters)
    return False
|
||||
|
||||
|
||||
class HFConfigParser(ConfigParserBase):
|
||||
def parse(
|
||||
self,
|
||||
@@ -346,7 +363,7 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
|
||||
config.rope_parameters["original_max_position_embeddings"] = ompe
|
||||
|
||||
# Handle nested rope_parameters in interleaved sliding attention models
|
||||
if set(config.rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
|
||||
if is_rope_parameters_nested(config.rope_parameters):
|
||||
for rope_parameters_layer_type in config.rope_parameters.values():
|
||||
patch_rope_parameters_dict(rope_parameters_layer_type)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user