Make max_model_len configurable (#972)
This commit is contained in:
@@ -38,6 +38,8 @@ class ModelConfig:
|
||||
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||
for BF16 models.
|
||||
seed: Random seed for reproducibility.
|
||||
max_model_len: Maximum length of a sequence (including prompt and
|
||||
output). If None, will be derived from the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -50,6 +52,7 @@ class ModelConfig:
|
||||
load_format: str,
|
||||
dtype: str,
|
||||
seed: int,
|
||||
max_model_len: Optional[int] = None,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
@@ -63,6 +66,16 @@ class ModelConfig:
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
self._verify_load_format()
|
||||
self._verify_tokenizer_mode()
|
||||
self.max_model_len = None
|
||||
if max_model_len is not None:
|
||||
derived_max_model_len = self.get_max_model_len()
|
||||
if max_model_len > derived_max_model_len:
|
||||
logger.warning(
|
||||
f"User-specified max_model_len ({max_model_len}) is "
|
||||
f"greater than the derived max_model_len "
|
||||
f"({derived_max_model_len}). Make sure the value is "
|
||||
"correct and within the model context size.")
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
@@ -134,6 +147,8 @@ class ModelConfig:
|
||||
return total_num_attention_heads // parallel_config.tensor_parallel_size
|
||||
|
||||
def get_max_model_len(self) -> int:
|
||||
if self.max_model_len is not None:
|
||||
return self.max_model_len
|
||||
max_model_len = float("inf")
|
||||
possible_keys = [
|
||||
# OPT
|
||||
|
||||
Reference in New Issue
Block a user