From eea3024f43e06ea4e037ec86464dcc249d0c0b44 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 12 Feb 2026 22:48:42 -0800 Subject: [PATCH] [Bugfix] Fix mamba state dtype setting for Qwen3-Next and Qwen3.5 (#34489) Signed-off-by: Roger Wang --- .../layers/mamba/mamba_utils.py | 6 ++-- vllm/model_executor/models/config.py | 29 +++++++++++++++++++ vllm/model_executor/models/qwen3_5.py | 5 ++-- vllm/model_executor/models/qwen3_next.py | 8 +++-- 4 files changed, 42 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 7181ada1c..d66dee7c9 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -80,9 +80,11 @@ class MambaStateDtypeCalculator: cls, model_dtype: ModelDType | torch.dtype, mamba_cache_dtype: MambaDType, + mamba_ssm_cache_dtype: MambaDType = "auto", ) -> tuple[torch.dtype, torch.dtype]: - state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) - return (state_dtype, state_dtype) + return cls._mamba_state_dtype( + model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype + ) @classmethod def kda_state_dtype( diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index a6c244b6e..749a97d0a 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -582,6 +582,33 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig): cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype +class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig): + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + """Update mamba_ssm_cache_dtype for Qwen3.5 models when set to 'auto' + (or not explicitly set), to the value specified in the HF config's + mamba_ssm_dtype field. Warn if the user explicitly overrides it to a + different value. + """ + cache_config = vllm_config.cache_config + hf_text_config = vllm_config.model_config.hf_text_config + mamba_ssm_dtype = getattr(hf_text_config, "mamba_ssm_dtype", None) + if cache_config.mamba_ssm_cache_dtype == "auto": + if mamba_ssm_dtype is not None: + cache_config.mamba_ssm_cache_dtype = mamba_ssm_dtype + elif ( + mamba_ssm_dtype is not None + and cache_config.mamba_ssm_cache_dtype != mamba_ssm_dtype + ): + logger.warning( + "Qwen3.5 model specifies mamba_ssm_dtype='%s' in its config, " + "but --mamba-ssm-cache-dtype='%s' was passed. " + "Using the user-specified value.", + mamba_ssm_dtype, + cache_config.mamba_ssm_cache_dtype, + ) + + class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_model_config(model_config: "ModelConfig") -> None: @@ -611,5 +638,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM, "NemotronHForCausalLM": NemotronHForCausalLMConfig, "NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig, + "Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig, + "Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig, "VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig, } diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py index c317c1e1a..55eb3408d 100644 --- a/vllm/model_executor/models/qwen3_5.py +++ b/vllm/model_executor/models/qwen3_5.py @@ -870,9 +870,10 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid) cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - mamba_ssm_dtype = vllm_config.model_config.hf_text_config.mamba_ssm_dtype return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - vllm_config.model_config.dtype, mamba_ssm_dtype + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, ) @classmethod diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index d0c13dd49..6da5bca1b 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -341,7 +341,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - self.model_config.dtype, self.cache_config.mamba_cache_dtype + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, ) def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: @@ -1372,7 +1374,9 @@ class Qwen3NextForCausalLM( vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, ) @classmethod