[Model] Add PLaMo2 (#14323)

Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Signed-off-by: shemmi <shemmi@preferred.jp>
Co-authored-by: Kento Nozawa <nzw0301@preferred.jp>
Co-authored-by: Hiroaki Mikami <mhiroaki@preferred.jp>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
This commit is contained in:
Shinichi Hemmi
2025-04-16 11:31:30 +09:00
committed by GitHub
parent fdcb850f14
commit 3badb0213b
9 changed files with 800 additions and 24 deletions

View File

@@ -2838,6 +2838,13 @@ def _get_and_verify_dtype(
else:
torch_dtype = config_dtype
if config.model_type == "plamo2":
logger.info(
"For PLaMo2, we cast models to bfloat16 instead of using "
"float16 by default. This is because float16 does not work."
)
torch_dtype = torch.bfloat16
from vllm.platforms import current_platform
if (current_platform.is_cpu()
and current_platform.get_cpu_architecture()
@@ -2867,6 +2874,11 @@ def _get_and_verify_dtype(
"using float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
elif dtype == "float16" and config.model_type == "plamo2":
logger.warning(
"For PLaMo2, using float16 is unstable and might cause "
"unexpected behavior. Please use bfloat16 or float32 instead.")
torch_dtype = torch.float16
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")