[Bugfix] Fix mamba state dtype setting for Qwen3-Next and Qwen3.5 (#34489)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang
2026-02-12 22:48:42 -08:00
committed by GitHub
parent 2f308214c0
commit eea3024f43
4 changed files with 42 additions and 6 deletions

View File

@@ -80,9 +80,11 @@ class MambaStateDtypeCalculator:
cls,
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType = "auto",
) -> tuple[torch.dtype, torch.dtype]:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype)
return cls._mamba_state_dtype(
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
)
@classmethod
def kda_state_dtype(

View File

@@ -582,6 +582,33 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Reconcile the mamba SSM cache dtype for Qwen3.5 models.

        When ``--mamba-ssm-cache-dtype`` is left at "auto", adopt the
        ``mamba_ssm_dtype`` advertised by the HF text config (if present).
        When the user explicitly set a different value, keep the user's
        choice but warn about the mismatch with the model's own config.
        """
        cfg = vllm_config.cache_config
        hf_dtype = getattr(
            vllm_config.model_config.hf_text_config, "mamba_ssm_dtype", None
        )
        if hf_dtype is None:
            # HF config does not pin an SSM dtype; nothing to reconcile.
            return
        if cfg.mamba_ssm_cache_dtype == "auto":
            cfg.mamba_ssm_cache_dtype = hf_dtype
        elif cfg.mamba_ssm_cache_dtype != hf_dtype:
            # Explicit user override that disagrees with the checkpoint:
            # honor the override, but surface the discrepancy.
            logger.warning(
                "Qwen3.5 model specifies mamba_ssm_dtype='%s' in its config, "
                "but --mamba-ssm-cache-dtype='%s' was passed. "
                "Using the user-specified value.",
                hf_dtype,
                cfg.mamba_ssm_cache_dtype,
            )
class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
@@ -611,5 +638,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
"NemotronHForCausalLM": NemotronHForCausalLMConfig,
"NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
"Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
"Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
"VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
}

View File

@@ -870,9 +870,10 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
mamba_ssm_dtype = vllm_config.model_config.hf_text_config.mamba_ssm_dtype
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
vllm_config.model_config.dtype, mamba_ssm_dtype
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod

View File

@@ -341,7 +341,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
self.model_config.dtype, self.cache_config.mamba_cache_dtype
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
self.cache_config.mamba_ssm_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
@@ -1372,7 +1374,9 @@ class Qwen3NextForCausalLM(
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod