[ROCm][Quantization] make quark ocp mx dtype parser robust for weight-only quantization (#36232)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
This commit is contained in:
xuebwang-amd
2026-03-20 23:10:03 +08:00
committed by GitHub
parent 8b6c6b9505
commit 44eea10f68
2 changed files with 19 additions and 5 deletions

View File

@@ -92,7 +92,8 @@ class QuarkMoEMethod(FusedMoEMethodBase):
rocm_aiter_ops.is_fused_moe_enabled()
)
if (
input_config.get("dtype") == "fp8_e4m3"
input_config is not None
and input_config.get("dtype") == "fp8_e4m3"
and not input_config.get("is_dynamic")
and not emulate
):

View File

@@ -176,7 +176,7 @@ class QuarkOCP_MX(QuarkScheme):
def __init__(
self,
weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any] | None,
dynamic_mxfp4_quant: bool = False,
):
self.out_dtype = torch.get_default_dtype()
@@ -185,7 +185,13 @@ class QuarkOCP_MX(QuarkScheme):
self.input_quant_spec = input_quant_spec
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
self.input_dtype: str | None = None
if input_quant_spec is not None:
input_quant = input_quant_spec["dtype"]
if input_quant == "fp8_e4m3":
self.input_dtype = "fp8"
else:
self.input_dtype = input_quant.replace("fp", "mxfp")
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
self.input_dtype, self.weight_dtype
@@ -200,14 +206,21 @@ class QuarkOCP_MX(QuarkScheme):
dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "")
)
if self.input_dtype == "mxfp4":
if self.input_dtype is None:
self.quant_dequant_func: Callable[[torch.Tensor], torch.Tensor] = (
lambda x: x
) # no input Q/DQ for weight-only
elif self.input_dtype == "mxfp4":
self.quant_dequant_func = quant_dequant_mxfp4
else:
self.quant_dequant_func = partial(
quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "")
)
self.static_input_scales = not input_quant_spec.get("is_dynamic")
if input_quant_spec is None:
self.static_input_scales = False
else:
self.static_input_scales = not input_quant_spec.get("is_dynamic")
if self.static_input_scales:
raise NotImplementedError(