[Bugfix] Fix mamba state dtype setting for Qwen3-Next and Qwen3.5 (#34489)
Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -80,9 +80,11 @@ class MambaStateDtypeCalculator:
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType = "auto",
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (state_dtype, state_dtype)
|
||||
return cls._mamba_state_dtype(
|
||||
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def kda_state_dtype(
|
||||
|
||||
@@ -582,6 +582,33 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
|
||||
cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
|
||||
|
||||
|
||||
class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
"""Update mamba_ssm_cache_dtype for Qwen3.5 models when set to 'auto'
|
||||
(or not explicitly set), to the value specified in the HF config's
|
||||
mamba_ssm_dtype field. Warn if the user explicitly overrides it to a
|
||||
different value.
|
||||
"""
|
||||
cache_config = vllm_config.cache_config
|
||||
hf_text_config = vllm_config.model_config.hf_text_config
|
||||
mamba_ssm_dtype = getattr(hf_text_config, "mamba_ssm_dtype", None)
|
||||
if cache_config.mamba_ssm_cache_dtype == "auto":
|
||||
if mamba_ssm_dtype is not None:
|
||||
cache_config.mamba_ssm_cache_dtype = mamba_ssm_dtype
|
||||
elif (
|
||||
mamba_ssm_dtype is not None
|
||||
and cache_config.mamba_ssm_cache_dtype != mamba_ssm_dtype
|
||||
):
|
||||
logger.warning(
|
||||
"Qwen3.5 model specifies mamba_ssm_dtype='%s' in its config, "
|
||||
"but --mamba-ssm-cache-dtype='%s' was passed. "
|
||||
"Using the user-specified value.",
|
||||
mamba_ssm_dtype,
|
||||
cache_config.mamba_ssm_cache_dtype,
|
||||
)
|
||||
|
||||
|
||||
class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
@@ -611,5 +638,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
|
||||
"NemotronHForCausalLM": NemotronHForCausalLMConfig,
|
||||
"NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
|
||||
"Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
|
||||
"Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
|
||||
"VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
|
||||
}
|
||||
|
||||
@@ -870,9 +870,10 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
mamba_ssm_dtype = vllm_config.model_config.hf_text_config.mamba_ssm_dtype
|
||||
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
||||
vllm_config.model_config.dtype, mamba_ssm_dtype
|
||||
vllm_config.model_config.dtype,
|
||||
vllm_config.cache_config.mamba_cache_dtype,
|
||||
vllm_config.cache_config.mamba_ssm_cache_dtype,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -341,7 +341,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
||||
self.model_config.dtype, self.cache_config.mamba_cache_dtype
|
||||
self.model_config.dtype,
|
||||
self.cache_config.mamba_cache_dtype,
|
||||
self.cache_config.mamba_ssm_cache_dtype,
|
||||
)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
@@ -1372,7 +1374,9 @@ class Qwen3NextForCausalLM(
|
||||
vllm_config: "VllmConfig",
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
||||
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
|
||||
vllm_config.model_config.dtype,
|
||||
vllm_config.cache_config.mamba_cache_dtype,
|
||||
vllm_config.cache_config.mamba_ssm_cache_dtype,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user