diff --git a/vllm/config/model.py b/vllm/config/model.py
index 1a39fb42e..d7ff55205 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1687,6 +1687,20 @@ class ModelConfig:
     def is_quantized(self) -> bool:
         return getattr(self.hf_config, "quantization_config", None) is not None
 
+    def is_nvfp4_quantized(self) -> bool:
+        # ModelOpt NVFP4 checkpoints resolve to the modelopt_fp4 quantization method
+        if self.quantization in ("modelopt_fp4",):
+            return True
+
+        # For Compressed Tensors we look for `"format": "nvfp4-pack-quantized"`
+        # in the quantization config
+        quant_config = self.model_arch_config.quantization_config
+        return (
+            self.quantization == "compressed-tensors"
+            and quant_config is not None
+            and "nvfp4" in quant_config.get("format", "").lower()
+        )
+
 
 def get_served_model_name(model: str, served_model_name: str | list[str] | None):
     """
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index d6f1202e5..63ce0f791 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -103,15 +103,21 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool:
 
 
 def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
-    """Enable if TP > 1 and Hopper+ and flashinfer installed."""
+    """Enable if TP > 1, DP == 1, Hopper/Blackwell, and flashinfer installed."""
     from vllm.platforms import current_platform
     from vllm.utils.flashinfer import has_flashinfer
 
     return (
         cfg.parallel_config.tensor_parallel_size > 1
         and current_platform.is_cuda()
-        and current_platform.has_device_capability(90)
         and has_flashinfer()
+        and (
+            current_platform.is_device_capability(100)
+            or current_platform.is_device_capability(90)
+        )
+        # tp-dp combination broken:
+        # https://github.com/vllm-project/vllm/issues/34458
+        and cfg.parallel_config.data_parallel_size == 1
     )
 
 
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 27cf3a792..e67a77005 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -536,12 +536,34 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         )
 
 
-class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
+class DeepseekV3ForCausalLM(VerifyAndUpdateConfig):
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """Disable AR-RMS-Quant fusion for DeepSeekV3 in NVFP4"""
+        # TODO: https://github.com/vllm-project/vllm/issues/34395
+
+        # disable AR-rms-fp4 fusion for DSv3+
+        ar_rms_enabled = vllm_config.compilation_config.pass_config.fuse_allreduce_rms
+        nvfp4 = vllm_config.model_config.is_nvfp4_quantized()
+
+        # Disable by default, warn if manually enabled:
+        if ar_rms_enabled is None and nvfp4:
+            vllm_config.compilation_config.pass_config.fuse_allreduce_rms = False
+        if ar_rms_enabled and nvfp4:
+            logger.warning(
+                "Allreduce-rms fusion broken for DeepSeekV3 with NVFP4 quant, "
+                "see https://github.com/vllm-project/vllm/issues/34395."
+            )
+
+
+class DeepseekV32ForCausalLM(DeepseekV3ForCausalLM):
     @classmethod
     def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         """
         Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
         """
+        super().verify_and_update_config(vllm_config)
+
         hf_config = vllm_config.model_config.hf_config
 
         # Mirror the check in vllm/model_executor/models/deepseek_v2.py
@@ -632,6 +654,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "MambaForCausalLM": MambaModelConfig,
     "Mamba2ForCausalLM": MambaModelConfig,
     "FalconMambaForCausalLM": MambaModelConfig,
+    "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
     "NemotronHForCausalLM": NemotronHForCausalLMConfig,
     "NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
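
For reviewers, a minimal standalone sketch of the NVFP4 detection rule that is_nvfp4_quantized() implements, assuming the compressed-tensors quantization config is available as a plain dict (as it appears in a checkpoint's config.json); the looks_like_nvfp4 helper and the sample configs are illustrative only and not part of this change:

# Illustrative sketch, not vLLM code: mirrors the checks in is_nvfp4_quantized().
def looks_like_nvfp4(quant_method: str | None, quant_config: dict | None) -> bool:
    # ModelOpt NVFP4 checkpoints resolve to the modelopt_fp4 quantization method.
    if quant_method == "modelopt_fp4":
        return True
    # Compressed Tensors NVFP4 checkpoints advertise "format": "nvfp4-pack-quantized".
    return (
        quant_method == "compressed-tensors"
        and quant_config is not None
        and "nvfp4" in quant_config.get("format", "").lower()
    )

# Hypothetical sample inputs, covering only the fields the check reads.
assert looks_like_nvfp4("modelopt_fp4", None)
assert looks_like_nvfp4("compressed-tensors", {"format": "nvfp4-pack-quantized"})
assert not looks_like_nvfp4("compressed-tensors", {"format": "float-quantized"})
assert not looks_like_nvfp4("fp8", None)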