[Model] Add PLaMo2 (#14323)
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Signed-off-by: shemmi <shemmi@preferred.jp>
Co-authored-by: Kento Nozawa <nzw0301@preferred.jp>
Co-authored-by: Hiroaki Mikami <mhiroaki@preferred.jp>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
This commit is contained in:
@@ -2838,6 +2838,13 @@ def _get_and_verify_dtype(
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
|
||||
if config.model_type == "plamo2":
|
||||
logger.info(
|
||||
"For PLaMo2, we cast models to bfloat16 instead of using "
|
||||
"float16 by default. This is because float16 does not work."
|
||||
)
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if (current_platform.is_cpu()
|
||||
and current_platform.get_cpu_architecture()
|
||||
@@ -2867,6 +2874,11 @@ def _get_and_verify_dtype(
|
||||
"using float16 by default. Please specify `dtype` if you "
|
||||
"want to use float16.")
|
||||
torch_dtype = torch.bfloat16
|
||||
elif dtype == "float16" and config.model_type == "plamo2":
|
||||
logger.warning(
|
||||
"For PLaMo2, using float16 is unstable and might cause "
|
||||
"unexpected behavior. Please use bfloat16 or float32 instead.")
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
Reference in New Issue
Block a user