[Bugfix] In LongRoPE, decide short vs long based on max_model_len (#27431)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -2142,8 +2142,18 @@ def _get_and_verify_max_len(
|
||||
# If the user didn't specify `max_model_len`, then use that derived from
|
||||
# the model config as a default value.
|
||||
if max_model_len is None:
|
||||
max_model_len = int(derived_max_model_len)
|
||||
# For LongRoPE, default to original_max_position_embeddings to avoid
|
||||
# performance degradation for shorter sequences
|
||||
if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
|
||||
max_model_len = int(
|
||||
getattr(
|
||||
hf_config, "original_max_position_embeddings", derived_max_model_len
|
||||
)
|
||||
)
|
||||
else:
|
||||
max_model_len = int(derived_max_model_len)
|
||||
max_model_len = current_platform.check_max_model_len(max_model_len)
|
||||
|
||||
# If the user specified a max length, make sure it is smaller than the
|
||||
# derived length from the HF model config.
|
||||
elif max_model_len > derived_max_model_len:
|
||||
|
||||
Reference in New Issue
Block a user