diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 36f20c89f..dedc7db38 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import ( ) from vllm.model_executor.models.utils import WeightsMapper from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig): self.kv_cache_group = kv_cache_group self.kv_cache_config = kv_cache_config self.pack_method = pack_method + self.dynamic_mxfp4_quant = False + + def maybe_update_config(self, model_name: str, revision: str | None = None): + self.hf_config = get_config( + model=model_name, + trust_remote_code=False, # or get from model_config if available + revision=revision, + config_format="auto", + ) + + quant_config = getattr(self.hf_config, "quantization_config", None) + if quant_config is not None: + quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"] + model_type = self.hf_config.model_type + if quant_dtype == "fp4" and model_type == "deepseek_v3": + self.dynamic_mxfp4_quant = True def get_linear_method(self) -> "QuarkLinearMethod": return QuarkLinearMethod(self) @@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig): if should_ignore_layer( prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping ): - return UnquantizedLinearMethod() + if ( + "self_attn" not in prefix # only quantize attention projections + or not getattr(self, "dynamic_mxfp4_quant", False) + or not isinstance(layer, LinearBase) # Ignore other methods + ): + return UnquantizedLinearMethod() + + scheme = self.get_scheme( + layer=layer, + layer_name=prefix, + dynamic_mxfp4_quant=True, + ) + layer.scheme = scheme + return QuarkLinearMethod(self) if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) layer.scheme = scheme @@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig): ) return global_quant_config - def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": + def _get_scheme_from_config( + self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False + ) -> "QuarkScheme": if config.get("output_tensors") or config.get("bias"): raise NotImplementedError( "Currently, Quark models with output_tensors " @@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig): input_symmetric=input_config.get("symmetric"), ) elif self._is_w_ocp_mx_a_x(weight_config, input_config): - return QuarkOCP_MX(weight_config, input_config) + return QuarkOCP_MX( + weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant + ) raise NotImplementedError( "No quark compatible scheme was found. " @@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig): f"Input config: {input_config}" ) - def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": + def get_scheme( + self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False + ) -> "QuarkScheme": layer_quant_config = self._find_matched_config(layer_name, layer) # Find the quant_scheme - scheme = self._get_scheme_from_config(layer_quant_config) + scheme = self._get_scheme_from_config( + layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant + ) # Raise error if device does not support the scheme # (e.g. fp8 needs ada lovelace) self._check_scheme_supported(scheme.get_min_capability()) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index c5f50122e..6917bb6f2 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_BLOCK_SIZE, OCP_MX_Scheme, ) -from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PackedvLLMParameter, +) +from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from .quark_scheme import QuarkScheme @@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError): class QuarkOCP_MX(QuarkScheme): def __init__( - self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any] + self, + weight_quant_spec: dict[str, Any], + input_quant_spec: dict[str, Any], + dynamic_mxfp4_quant: bool = False, ): self.out_dtype = torch.get_default_dtype() self.qscheme = "per_group" self.weight_quant_spec = weight_quant_spec self.input_quant_spec = input_quant_spec - + self.dynamic_mxfp4_quant = dynamic_mxfp4_quant self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp") self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp") @@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme): layer.weight_scale.data, requires_grad=False ) else: - if self.rocm_use_aiter_fp4_asm_gemm: + if self.dynamic_mxfp4_quant: + w_q, w_s = dynamic_mxfp4_quant(layer.weight) + layer.weight_scale = torch.nn.Parameter( + w_s.T.contiguous(), requires_grad=False + ) + layer.weight = torch.nn.Parameter(w_q, requires_grad=False) + elif self.rocm_use_aiter_fp4_asm_gemm: # shuffle weight scale weight_scale_shuffle = layer.weight_scale.data sm, sn = weight_scale_shuffle.shape @@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme): weight_loader: Callable, **kwargs, ): - output_size_per_partition = sum(output_partition_sizes) - layer.logical_widths = output_partition_sizes + if self.dynamic_mxfp4_quant: + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) - # WEIGHT - weight = PackedvLLMParameter( - data=torch.empty( - output_size_per_partition, - self.get_packed_dim(input_size_per_partition, self.weight_dtype), - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - packed_dim=1, - packed_factor=self.packed_factor, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, kwargs) + else: + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes - # WEIGHT SCALE - weight_scale = GroupQuantScaleParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition // OCP_MX_BLOCK_SIZE, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight_scale", weight_scale) + # WEIGHT + weight = PackedvLLMParameter( + data=torch.empty( + output_size_per_partition, + self.get_packed_dim(input_size_per_partition, self.weight_dtype), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=self.packed_factor, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) def apply_weights( self,