[ROCm][Quantization] Make the Quark OCP MX dtype parser robust for weight-only quantization (#36232)
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
This commit is contained in:
@@ -92,7 +92,8 @@ class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
if (
|
||||
input_config.get("dtype") == "fp8_e4m3"
|
||||
input_config is not None
|
||||
and input_config.get("dtype") == "fp8_e4m3"
|
||||
and not input_config.get("is_dynamic")
|
||||
and not emulate
|
||||
):
|
||||
|
||||
@@ -176,7 +176,7 @@ class QuarkOCP_MX(QuarkScheme):
|
||||
def __init__(
|
||||
self,
|
||||
weight_quant_spec: dict[str, Any],
|
||||
input_quant_spec: dict[str, Any],
|
||||
input_quant_spec: dict[str, Any] | None,
|
||||
dynamic_mxfp4_quant: bool = False,
|
||||
):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
@@ -185,7 +185,13 @@ class QuarkOCP_MX(QuarkScheme):
|
||||
self.input_quant_spec = input_quant_spec
|
||||
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
|
||||
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
self.input_dtype: str | None = None
|
||||
if input_quant_spec is not None:
|
||||
input_quant = input_quant_spec["dtype"]
|
||||
if input_quant == "fp8_e4m3":
|
||||
self.input_dtype = "fp8"
|
||||
else:
|
||||
self.input_dtype = input_quant.replace("fp", "mxfp")
|
||||
|
||||
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
|
||||
self.input_dtype, self.weight_dtype
|
||||
@@ -200,14 +206,21 @@ class QuarkOCP_MX(QuarkScheme):
|
||||
dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "")
|
||||
)
|
||||
|
||||
if self.input_dtype == "mxfp4":
|
||||
if self.input_dtype is None:
|
||||
self.quant_dequant_func: Callable[[torch.Tensor], torch.Tensor] = (
|
||||
lambda x: x
|
||||
) # no input Q/DQ for weight-only
|
||||
elif self.input_dtype == "mxfp4":
|
||||
self.quant_dequant_func = quant_dequant_mxfp4
|
||||
else:
|
||||
self.quant_dequant_func = partial(
|
||||
quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "")
|
||||
)
|
||||
|
||||
self.static_input_scales = not input_quant_spec.get("is_dynamic")
|
||||
if input_quant_spec is None:
|
||||
self.static_input_scales = False
|
||||
else:
|
||||
self.static_input_scales = not input_quant_spec.get("is_dynamic")
|
||||
|
||||
if self.static_input_scales:
|
||||
raise NotImplementedError(
|
||||
|
||||
Reference in New Issue
Block a user