NemotronH: default mamba_ssm_cache_dtype to float32; enable auto-hook for NemotronHNanoVLV2Config (#39032)
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
@@ -7,7 +7,9 @@ from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -346,17 +348,20 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
|
||||
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
DEFAULT_MAMBA_SSM_CACHE_DTYPE = "float32"
|
||||
"""Only `float32` is known to have no accuracy issues by default."""
|
||||
|
||||
@classmethod
|
||||
def update_mamba_ssm_cache_dtype(
|
||||
cls, *, cache_config: "CacheConfig", hf_config: "PretrainedConfig"
|
||||
) -> None:
|
||||
"""Update mamba_ssm_cache_dtype for NemotronH models when set to 'auto'
|
||||
(or not explicitly set), to the value specified in the HF config, or to
|
||||
float16 if not specified.
|
||||
`float32` if not specified.
|
||||
"""
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config.mamba_ssm_cache_dtype == "auto":
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
mamba_ssm_cache_dtype = getattr(
|
||||
hf_config, "mamba_ssm_cache_dtype", "float16"
|
||||
hf_config, "mamba_ssm_cache_dtype", cls.DEFAULT_MAMBA_SSM_CACHE_DTYPE
|
||||
)
|
||||
logger.info(
|
||||
"Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
|
||||
@@ -364,8 +369,22 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
|
||||
)
|
||||
cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
|
||||
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
cls.update_mamba_ssm_cache_dtype(
|
||||
cache_config=vllm_config.cache_config,
|
||||
hf_config=vllm_config.model_config.hf_config,
|
||||
)
|
||||
|
||||
|
||||
class NemotronHNanoVLV2Config(VerifyAndUpdateConfig):
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
NemotronHForCausalLMConfig.update_mamba_ssm_cache_dtype(
|
||||
cache_config=vllm_config.cache_config,
|
||||
hf_config=vllm_config.model_config.hf_config.text_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
mm_config = model_config.multimodal_config
|
||||
|
||||
Reference in New Issue
Block a user