[Frontend][Model] Add 'float16' to possible mamba cache dtype values, override mamba SSM cache dtype value for NemotronH (#29978)
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
This commit is contained in:
@@ -485,6 +485,26 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
|
||||
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
|
||||
|
||||
|
||||
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Resolve an 'auto' mamba SSM cache dtype for NemotronH models.

        When ``cache_config.mamba_ssm_cache_dtype`` is 'auto', it is replaced
        by the ``mamba_ssm_cache_dtype`` attribute of the HF config, falling
        back to 'float16' when the HF config has no such attribute. Any other
        value is left untouched.
        """
        cache_cfg = vllm_config.cache_config

        # Only the 'auto' placeholder is overridden; other values are kept.
        if cache_cfg.mamba_ssm_cache_dtype != "auto":
            return

        resolved_dtype = getattr(
            vllm_config.model_config.hf_config,
            "mamba_ssm_cache_dtype",
            "float16",
        )
        logger.info(
            "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
            resolved_dtype,
        )
        cache_cfg.mamba_ssm_cache_dtype = resolved_dtype
|
||||
|
||||
|
||||
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"GteModel": SnowflakeGteNewModelConfig,
|
||||
"GteNewModel": GteNewModelConfig,
|
||||
@@ -502,4 +522,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"Mamba2ForCausalLM": MambaModelConfig,
|
||||
"FalconMambaForCausalLM": MambaModelConfig,
|
||||
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
|
||||
"NemotronHForCausalLM": NemotronHForCausalLMConfig,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user