[Misc] Fix skipped max-model-len validation when deriving max model length from tokenizer config (#19660)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
committed by
GitHub
parent
367871a469
commit
b692e9cd07
@@ -1429,25 +1429,19 @@ class ModelConfig:
|
||||
return getattr(self.hf_config, "matryoshka_dimensions", None)
|
||||
|
||||
def get_and_verify_max_len(self, max_model_len: int):
|
||||
tokenizer_config = try_get_tokenizer_config(
|
||||
self.tokenizer,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
revision=self.tokenizer_revision)
|
||||
max_model_len = _get_and_verify_max_len(
|
||||
hf_config=self.hf_text_config,
|
||||
tokenizer_config=tokenizer_config,
|
||||
max_model_len=max_model_len,
|
||||
disable_sliding_window=self.disable_sliding_window,
|
||||
sliding_window_len=self.get_hf_config_sliding_window(),
|
||||
spec_target_max_model_len=self.spec_target_max_model_len,
|
||||
encoder_config=self.encoder_config)
|
||||
|
||||
tokenizer_config = try_get_tokenizer_config(
|
||||
self.tokenizer,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
revision=self.tokenizer_revision)
|
||||
|
||||
if tokenizer_config is None:
|
||||
return max_model_len
|
||||
|
||||
model_max_length = tokenizer_config.get("model_max_length",
|
||||
max_model_len)
|
||||
max_model_len = min(max_model_len, model_max_length)
|
||||
logger.info("Using max model len %s", max_model_len)
|
||||
return max_model_len
|
||||
|
||||
|
||||
@@ -3283,6 +3277,7 @@ def _get_and_verify_dtype(
|
||||
|
||||
def _get_and_verify_max_len(
|
||||
hf_config: PretrainedConfig,
|
||||
tokenizer_config: Optional[dict],
|
||||
max_model_len: Optional[int],
|
||||
disable_sliding_window: bool,
|
||||
sliding_window_len: Optional[Union[int, list[Optional[int]]]],
|
||||
@@ -3309,7 +3304,7 @@ def _get_and_verify_max_len(
|
||||
"max_seq_length",
|
||||
"seq_len",
|
||||
]
|
||||
# Choose the smallest "max_length" from the possible keys.
|
||||
# Choose the smallest "max_length" from the possible keys
|
||||
max_len_key = None
|
||||
for key in possible_keys:
|
||||
max_len = getattr(hf_config, key, None)
|
||||
@@ -3332,6 +3327,13 @@ def _get_and_verify_max_len(
|
||||
derived_max_model_len = min(derived_max_model_len,
|
||||
sliding_window_len_min)
|
||||
|
||||
# Consider model_max_length in tokenizer_config
|
||||
if tokenizer_config:
|
||||
tokenizer_model_max_length = tokenizer_config.get(
|
||||
"model_max_length", derived_max_model_len)
|
||||
derived_max_model_len = min(derived_max_model_len,
|
||||
tokenizer_model_max_length)
|
||||
|
||||
# If none of the keys were found in the config, use a default and
|
||||
# log a warning.
|
||||
if derived_max_model_len == float("inf"):
|
||||
|
||||
Reference in New Issue
Block a user