[Misc] Fix skipped max-model-len validation when deriving max model length from tokenizer config (#19660)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Author: Ye (Charlotte) Qi
Date: 2025-06-15 23:30:29 -07:00
committed by GitHub
parent 367871a469
commit b692e9cd07
2 changed files with 43 additions and 13 deletions
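
For context on what "skipped validation" means here: previously, the tokenizer's model_max_length was folded in with a min() only after _get_and_verify_max_len had already checked the requested max_model_len, so the request was never validated against the tokenizer limit and was silently clamped instead. The standalone sketch below (not vLLM code; the helper verify() and the numbers are made up) shows why the ordering matters:

def verify(requested: int, limit: int) -> int:
    """Stand-in for the check done by _get_and_verify_max_len."""
    if requested > limit:
        raise ValueError(
            f"max_model_len ({requested}) exceeds the derived limit ({limit})")
    return requested

hf_limit = 8192          # e.g. max_position_embeddings from the HF config
tokenizer_limit = 2048   # e.g. model_max_length from tokenizer_config.json
requested = 4096         # value requested by the user

# Old flow: validate against the HF limit first, clamp afterwards.
# The tokenizer limit never takes part in the check; the result is
# silently reduced to 2048.
old_result = min(verify(requested, hf_limit), tokenizer_limit)

# New flow: fold the tokenizer limit into the derived limit *before*
# validating, so the same request now fails loudly.
try:
    new_result = verify(requested, min(hf_limit, tokenizer_limit))
except ValueError as exc:
    print(f"rejected: {exc}")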

vllm/config.py

@@ -1429,25 +1429,19 @@ class ModelConfig:
         return getattr(self.hf_config, "matryoshka_dimensions", None)
 
     def get_and_verify_max_len(self, max_model_len: int):
+        tokenizer_config = try_get_tokenizer_config(
+            self.tokenizer,
+            trust_remote_code=self.trust_remote_code,
+            revision=self.tokenizer_revision)
         max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
+            tokenizer_config=tokenizer_config,
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
             spec_target_max_model_len=self.spec_target_max_model_len,
             encoder_config=self.encoder_config)
-        tokenizer_config = try_get_tokenizer_config(
-            self.tokenizer,
-            trust_remote_code=self.trust_remote_code,
-            revision=self.tokenizer_revision)
-
-        if tokenizer_config is None:
-            return max_model_len
-
-        model_max_length = tokenizer_config.get("model_max_length",
-                                                max_model_len)
-        max_model_len = min(max_model_len, model_max_length)
         logger.info("Using max model len %s", max_model_len)
         return max_model_len
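
From the user's side, the practical effect of the hunk above is that an oversized max_model_len can now be rejected against the tokenizer's model_max_length rather than being silently reduced. A hedged usage sketch follows; the model name, the numbers, and the exact error text are illustrative only, and it assumes the usual behavior that validation failures surface as a ValueError during engine construction (VLLM_ALLOW_LONG_MAX_MODEL_LEN remains the escape hatch for deliberately exceeding the derived limit).

from vllm import LLM

try:
    # Suppose this (hypothetical) model's tokenizer_config.json sets
    # model_max_length = 2048 while we ask for 4096.
    llm = LLM(model="some-org/some-model", max_model_len=4096)
except ValueError as exc:
    # With this fix the tokenizer limit participates in validation,
    # so the mismatch is reported instead of being silently clamped.
    print(f"engine refused to start: {exc}")
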
@@ -3283,6 +3277,7 @@ def _get_and_verify_dtype(
 
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
+    tokenizer_config: Optional[dict],
     max_model_len: Optional[int],
     disable_sliding_window: bool,
     sliding_window_len: Optional[Union[int, list[Optional[int]]]],
@@ -3309,7 +3304,7 @@ def _get_and_verify_max_len(
         "max_seq_length",
         "seq_len",
     ]
-    # Choose the smallest "max_length" from the possible keys.
+    # Choose the smallest "max_length" from the possible keys
     max_len_key = None
     for key in possible_keys:
         max_len = getattr(hf_config, key, None)
@@ -3332,6 +3327,13 @@ def _get_and_verify_max_len(
             derived_max_model_len = min(derived_max_model_len,
                                         sliding_window_len_min)
 
+    # Consider model_max_length in tokenizer_config
+    if tokenizer_config:
+        tokenizer_model_max_length = tokenizer_config.get(
+            "model_max_length", derived_max_model_len)
+        derived_max_model_len = min(derived_max_model_len,
+                                    tokenizer_model_max_length)
+
     # If none of the keys were found in the config, use a default and
     # log a warning.
     if derived_max_model_len == float("inf"):
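
The added block slots the tokenizer limit into the same min-chain that already considers the HF config keys and the (disabled) sliding window. Below is a simplified sketch of that derivation order, using plain dicts in place of the real PretrainedConfig and tokenizer config objects; the helper name derive_limit and the trimmed key list are made up for illustration.

from typing import Optional

def derive_limit(hf_config: dict,
                 tokenizer_config: Optional[dict],
                 sliding_window_len: Optional[int],
                 disable_sliding_window: bool) -> float:
    """Simplified, illustrative version of the derivation in _get_and_verify_max_len."""
    possible_keys = ["model_max_length", "max_position_embeddings",
                     "max_seq_length", "seq_len"]
    derived = float("inf")
    # Smallest "max_length" found among the HF config keys.
    for key in possible_keys:
        value = hf_config.get(key)
        if value is not None:
            derived = min(derived, value)
    # The sliding-window cap only applies when sliding window is disabled.
    if disable_sliding_window and sliding_window_len is not None:
        derived = min(derived, sliding_window_len)
    # New in this commit: the tokenizer's model_max_length joins the chain.
    if tokenizer_config:
        derived = min(derived, tokenizer_config.get("model_max_length", derived))
    return derived

print(derive_limit({"max_position_embeddings": 8192},
                   {"model_max_length": 2048},
                   sliding_window_len=None,
                   disable_sliding_window=False))  # -> 2048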