[ROCm] Add dynamic mxfp4 quantization for DeepSeek V2 projection layers (#34157)
Signed-off-by: Doug Lehr <douglehr@amd.com> Signed-off-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
This commit is contained in:
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
|
||||
)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
@@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig):
|
||||
self.kv_cache_group = kv_cache_group
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.pack_method = pack_method
|
||||
self.dynamic_mxfp4_quant = False
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
self.hf_config = get_config(
|
||||
model=model_name,
|
||||
trust_remote_code=False, # or get from model_config if available
|
||||
revision=revision,
|
||||
config_format="auto",
|
||||
)
|
||||
|
||||
quant_config = getattr(self.hf_config, "quantization_config", None)
|
||||
if quant_config is not None:
|
||||
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
|
||||
model_type = self.hf_config.model_type
|
||||
if quant_dtype == "fp4" and model_type == "deepseek_v3":
|
||||
self.dynamic_mxfp4_quant = True
|
||||
|
||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||
return QuarkLinearMethod(self)
|
||||
@@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig):
|
||||
if should_ignore_layer(
|
||||
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
if (
|
||||
"self_attn" not in prefix # only quantize attention projections
|
||||
or not getattr(self, "dynamic_mxfp4_quant", False)
|
||||
or not isinstance(layer, LinearBase) # Ignore other methods
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
scheme = self.get_scheme(
|
||||
layer=layer,
|
||||
layer_name=prefix,
|
||||
dynamic_mxfp4_quant=True,
|
||||
)
|
||||
layer.scheme = scheme
|
||||
return QuarkLinearMethod(self)
|
||||
if isinstance(layer, LinearBase):
|
||||
scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
layer.scheme = scheme
|
||||
@@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig):
|
||||
)
|
||||
return global_quant_config
|
||||
|
||||
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
|
||||
def _get_scheme_from_config(
|
||||
self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False
|
||||
) -> "QuarkScheme":
|
||||
if config.get("output_tensors") or config.get("bias"):
|
||||
raise NotImplementedError(
|
||||
"Currently, Quark models with output_tensors "
|
||||
@@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig):
|
||||
input_symmetric=input_config.get("symmetric"),
|
||||
)
|
||||
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
return QuarkOCP_MX(weight_config, input_config)
|
||||
return QuarkOCP_MX(
|
||||
weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"No quark compatible scheme was found. "
|
||||
@@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig):
|
||||
f"Input config: {input_config}"
|
||||
)
|
||||
|
||||
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
|
||||
def get_scheme(
|
||||
self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False
|
||||
) -> "QuarkScheme":
|
||||
layer_quant_config = self._find_matched_config(layer_name, layer)
|
||||
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_config(layer_quant_config)
|
||||
scheme = self._get_scheme_from_config(
|
||||
layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
|
||||
)
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
self._check_scheme_supported(scheme.get_min_capability())
|
||||
|
||||
@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
)
|
||||
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .quark_scheme import QuarkScheme
|
||||
@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError):
|
||||
|
||||
class QuarkOCP_MX(QuarkScheme):
|
||||
def __init__(
|
||||
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
|
||||
self,
|
||||
weight_quant_spec: dict[str, Any],
|
||||
input_quant_spec: dict[str, Any],
|
||||
dynamic_mxfp4_quant: bool = False,
|
||||
):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.qscheme = "per_group"
|
||||
self.weight_quant_spec = weight_quant_spec
|
||||
self.input_quant_spec = input_quant_spec
|
||||
|
||||
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
|
||||
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
|
||||
@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme):
|
||||
layer.weight_scale.data, requires_grad=False
|
||||
)
|
||||
else:
|
||||
if self.rocm_use_aiter_fp4_asm_gemm:
|
||||
if self.dynamic_mxfp4_quant:
|
||||
w_q, w_s = dynamic_mxfp4_quant(layer.weight)
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
w_s.T.contiguous(), requires_grad=False
|
||||
)
|
||||
layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
|
||||
elif self.rocm_use_aiter_fp4_asm_gemm:
|
||||
# shuffle weight scale
|
||||
weight_scale_shuffle = layer.weight_scale.data
|
||||
sm, sn = weight_scale_shuffle.shape
|
||||
@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme):
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
if self.dynamic_mxfp4_quant:
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
weight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.packed_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, kwargs)
|
||||
else:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
# WEIGHT
|
||||
weight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.packed_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user