[Bugfix] Limit the default value of max_model_len when it is not specified by users (#27556)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -608,6 +608,13 @@ class Platform:
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def check_max_model_len(cls, max_model_len: int) -> int:
|
||||
"""
|
||||
Check max_model_len for the current platform.
|
||||
"""
|
||||
return max_model_len
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
||||
@@ -251,6 +251,22 @@ class TpuPlatform(Platform):
|
||||
def use_sync_weight_loader(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def check_max_model_len(cls, max_model_len: int) -> int:
|
||||
"""
|
||||
Check max_model_len for the current platform.
|
||||
"""
|
||||
logger.warning(
|
||||
"--max-model-len is not specified, "
|
||||
"it's currently using model's default length %d, "
|
||||
"which might be too large."
|
||||
"Please input with --max-model-len based on your "
|
||||
"request input length and output length, to avoid "
|
||||
"unnecessary degradation.",
|
||||
max_model_len,
|
||||
)
|
||||
return max_model_len
|
||||
|
||||
|
||||
try:
|
||||
from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
|
||||
|
||||
Reference in New Issue
Block a user