[Bugfix] Fix mamba state dtype setting for Qwen3-Next and Qwen3.5 (#34489)
Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -80,9 +80,11 @@ class MambaStateDtypeCalculator:
|
|||||||
cls,
|
cls,
|
||||||
model_dtype: ModelDType | torch.dtype,
|
model_dtype: ModelDType | torch.dtype,
|
||||||
mamba_cache_dtype: MambaDType,
|
mamba_cache_dtype: MambaDType,
|
||||||
|
mamba_ssm_cache_dtype: MambaDType = "auto",
|
||||||
) -> tuple[torch.dtype, torch.dtype]:
|
) -> tuple[torch.dtype, torch.dtype]:
|
||||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
return cls._mamba_state_dtype(
|
||||||
return (state_dtype, state_dtype)
|
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def kda_state_dtype(
|
def kda_state_dtype(
|
||||||
|
|||||||
@@ -582,6 +582,33 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
|
|||||||
cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
|
cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Reconcile the cache's mamba_ssm_cache_dtype with the HF config.

        Reads ``mamba_ssm_dtype`` from the model's HF text config (it may be
        absent, in which case nothing is changed or logged):

        * If ``cache_config.mamba_ssm_cache_dtype`` is ``"auto"``, it is
          replaced by the HF config's value.
        * Otherwise, if it differs from the HF config's value, a warning is
          logged and the existing (user-supplied) value is left in place.
        """
        cache_config = vllm_config.cache_config
        hf_dtype = getattr(
            vllm_config.model_config.hf_text_config, "mamba_ssm_dtype", None
        )
        requested = cache_config.mamba_ssm_cache_dtype

        # Without a model-provided dtype there is nothing to adopt or compare.
        if hf_dtype is None:
            return

        if requested == "auto":
            cache_config.mamba_ssm_cache_dtype = hf_dtype
            return

        if requested != hf_dtype:
            # The explicit setting is kept; only make the mismatch visible.
            logger.warning(
                "Qwen3.5 model specifies mamba_ssm_dtype='%s' in its config, "
                "but --mamba-ssm-cache-dtype='%s' was passed. "
                "Using the user-specified value.",
                hf_dtype,
                requested,
            )
|
||||||
|
|
||||||
|
|
||||||
class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
|
class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
@@ -611,5 +638,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
|||||||
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
|
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
|
||||||
"NemotronHForCausalLM": NemotronHForCausalLMConfig,
|
"NemotronHForCausalLM": NemotronHForCausalLMConfig,
|
||||||
"NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
|
"NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
|
||||||
|
"Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
|
||||||
|
"Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
|
||||||
"VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
|
"VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -870,9 +870,10 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
|
|||||||
cls,
|
cls,
|
||||||
vllm_config: "VllmConfig",
|
vllm_config: "VllmConfig",
|
||||||
) -> tuple[torch.dtype, torch.dtype]:
|
) -> tuple[torch.dtype, torch.dtype]:
|
||||||
mamba_ssm_dtype = vllm_config.model_config.hf_text_config.mamba_ssm_dtype
|
|
||||||
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
||||||
vllm_config.model_config.dtype, mamba_ssm_dtype
|
vllm_config.model_config.dtype,
|
||||||
|
vllm_config.cache_config.mamba_cache_dtype,
|
||||||
|
vllm_config.cache_config.mamba_ssm_cache_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -341,7 +341,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
|
|
||||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||||
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
||||||
self.model_config.dtype, self.cache_config.mamba_cache_dtype
|
self.model_config.dtype,
|
||||||
|
self.cache_config.mamba_cache_dtype,
|
||||||
|
self.cache_config.mamba_ssm_cache_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||||
@@ -1372,7 +1374,9 @@ class Qwen3NextForCausalLM(
|
|||||||
vllm_config: "VllmConfig",
|
vllm_config: "VllmConfig",
|
||||||
) -> tuple[torch.dtype, torch.dtype]:
|
) -> tuple[torch.dtype, torch.dtype]:
|
||||||
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
||||||
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
|
vllm_config.model_config.dtype,
|
||||||
|
vllm_config.cache_config.mamba_cache_dtype,
|
||||||
|
vllm_config.cache_config.mamba_ssm_cache_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user