[Misc] Standardize RoPE handling for Qwen2-VL (#9250)
This commit is contained in:
@@ -1739,16 +1739,10 @@ def _get_and_verify_max_len(
|
||||
|
||||
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||
if rope_scaling is not None:
|
||||
if "type" in rope_scaling:
|
||||
rope_type = rope_scaling["type"]
|
||||
elif "rope_type" in rope_scaling:
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"rope_scaling must have a 'type' or 'rope_type' key.")
|
||||
# No need to consider "type" key because of patch_rope_scaling when
|
||||
# loading HF config
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
|
||||
# The correct one should be "longrope", kept "su" here
|
||||
# to be backward compatible
|
||||
if rope_type not in ("su", "longrope", "llama3"):
|
||||
if disable_sliding_window:
|
||||
# TODO(robertgshaw): Find a model that supports rope_scaling
|
||||
@@ -1758,11 +1752,10 @@ def _get_and_verify_max_len(
|
||||
"with rope_scaling. Please raise an issue so we can "
|
||||
"investigate.")
|
||||
|
||||
if rope_type == "mrope":
|
||||
scaling_factor = 1
|
||||
else:
|
||||
assert "factor" in rope_scaling
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
# NOTE: rope_type == "default" does not define factor
|
||||
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
|
||||
scaling_factor = rope_scaling.get("factor", 1.0)
|
||||
|
||||
if rope_type == "yarn":
|
||||
derived_max_model_len = rope_scaling[
|
||||
"original_max_position_embeddings"]
|
||||
|
||||
Reference in New Issue
Block a user