diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 0a5db4e71..4ebf8c439 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -92,7 +92,8 @@ class QuarkMoEMethod(FusedMoEMethodBase): rocm_aiter_ops.is_fused_moe_enabled() ) if ( - input_config.get("dtype") == "fp8_e4m3" + input_config is not None + and input_config.get("dtype") == "fp8_e4m3" and not input_config.get("is_dynamic") and not emulate ): diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index 6917bb6f2..1b30f5b82 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -176,7 +176,7 @@ class QuarkOCP_MX(QuarkScheme): def __init__( self, weight_quant_spec: dict[str, Any], - input_quant_spec: dict[str, Any], + input_quant_spec: dict[str, Any] | None, dynamic_mxfp4_quant: bool = False, ): self.out_dtype = torch.get_default_dtype() @@ -185,7 +185,13 @@ class QuarkOCP_MX(QuarkScheme): self.input_quant_spec = input_quant_spec self.dynamic_mxfp4_quant = dynamic_mxfp4_quant self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp") - self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp") + self.input_dtype: str | None = None + if input_quant_spec is not None: + input_quant = input_quant_spec["dtype"] + if input_quant == "fp8_e4m3": + self.input_dtype = "fp8" + else: + self.input_dtype = input_quant.replace("fp", "mxfp") self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( self.input_dtype, self.weight_dtype @@ -200,14 +206,21 @@ class QuarkOCP_MX(QuarkScheme): dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "") ) - if self.input_dtype == "mxfp4": + if self.input_dtype is None: + self.quant_dequant_func: Callable[[torch.Tensor], torch.Tensor] = ( + lambda x: x + ) # no input Q/DQ for weight-only + elif self.input_dtype == "mxfp4": self.quant_dequant_func = quant_dequant_mxfp4 else: self.quant_dequant_func = partial( quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "") ) - self.static_input_scales = not input_quant_spec.get("is_dynamic") + if input_quant_spec is None: + self.static_input_scales = False + else: + self.static_input_scales = not input_quant_spec.get("is_dynamic") if self.static_input_scales: raise NotImplementedError(