NemotronH default mamba_ssm_cache_dtype=float32; enable auto-hook for NemotronHNanoVLV2Config (#39032)

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
Netanel Haber
2026-04-06 22:47:46 +03:00
committed by GitHub
parent e8ebbdde83
commit dfa5062a8f

View File

@@ -7,7 +7,9 @@ from vllm.logger import init_logger
from vllm.utils.math_utils import round_up
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
logger = init_logger(__name__)
@@ -346,17 +348,20 @@ class MambaModelConfig(VerifyAndUpdateConfig):
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
DEFAULT_MAMBA_SSM_CACHE_DTYPE = "float32"
"""Only `float32` is known to have no accuracy issues by default."""
@classmethod
def update_mamba_ssm_cache_dtype(
cls, *, cache_config: "CacheConfig", hf_config: "PretrainedConfig"
) -> None:
"""Update mamba_ssm_cache_dtype for NemotronH models when set to 'auto'
(or not explicitly set), to the value specified in the HF config, or to
float16 if not specified.
`float32` if not specified.
"""
cache_config = vllm_config.cache_config
if cache_config.mamba_ssm_cache_dtype == "auto":
hf_config = vllm_config.model_config.hf_config
mamba_ssm_cache_dtype = getattr(
hf_config, "mamba_ssm_cache_dtype", "float16"
hf_config, "mamba_ssm_cache_dtype", cls.DEFAULT_MAMBA_SSM_CACHE_DTYPE
)
logger.info(
"Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
@@ -364,8 +369,22 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
)
cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
# Config hook for NemotronH: hands the engine's cache config and the model's
# HF config to `update_mamba_ssm_cache_dtype`, which replaces an "auto"
# `mamba_ssm_cache_dtype` on the cache config in place. Returns nothing.
cls.update_mamba_ssm_cache_dtype(
cache_config=vllm_config.cache_config,
hf_config=vllm_config.model_config.hf_config,
)
class NemotronHNanoVLV2Config(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
# Config hook for NemotronHNanoVLV2Config: reuses the NemotronH helper so an
# "auto" `mamba_ssm_cache_dtype` is resolved the same way as for
# NemotronHForCausalLMConfig. Mutates the cache config; returns nothing.
NemotronHForCausalLMConfig.update_mamba_ssm_cache_dtype(
cache_config=vllm_config.cache_config,
# The nested `text_config` is passed instead of the top-level HF config.
# NOTE(review): presumably the Mamba settings live on the language-model
# sub-config of this VL model -- confirm `text_config` is always present.
hf_config=vllm_config.model_config.hf_config.text_config,
)
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
mm_config = model_config.multimodal_config