[ROCm][Quantization] GPT_OSS in amd-quark format model loading and emulations (#29008)
Signed-off-by: xuebwang-amd <xuebwang@amd.com> Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
@@ -298,12 +298,18 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
|
||||
pc2,
|
||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_scale=pc1,
|
||||
w2_scale=pc2,
|
||||
)
|
||||
if a_dtype == "bf16" and w_dtype == "mx4":
|
||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale=pc1,
|
||||
w2_scale=pc2,
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Quantization configuration for activation={a_dtype} and weight={w_dtype} "
|
||||
f"has not been implemented."
|
||||
)
|
||||
|
||||
out_triton_monolithic = triton_kernel_moe_forward(
|
||||
hidden_states=x_tri,
|
||||
|
||||
110
tests/models/quantization/test_gpt_oss.py
Normal file
110
tests/models/quantization/test_gpt_oss.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
End-to-end accuracy test for GPT-OSS model quantization.
|
||||
|
||||
Config:
|
||||
Task: gsm8k_platinum
|
||||
Filter: flexible-extract
|
||||
n-shot: 5
|
||||
Metric: exact_match
|
||||
|
||||
Run: pytest tests/models/quantization/test_gpt_oss.py
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
from dataclasses import dataclass
|
||||
|
||||
import huggingface_hub
|
||||
import lm_eval
|
||||
import pytest
|
||||
from packaging import version
|
||||
|
||||
MODEL_ACCURACIES = {
|
||||
# Full quantization: attention linears and MoE linears
|
||||
"amd/gpt-oss-20b-WFP8-AFP8-KVFP8": 0.89,
|
||||
# MoE linears only quantization
|
||||
"amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8": 0.89,
|
||||
# MoE linears only quantization
|
||||
# "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-MXFP4-KV-FP8": 0.90,
|
||||
}
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")
|
||||
) >= version.parse("0.9.0")
|
||||
|
||||
|
||||
def has_huggingface_access(repo):
|
||||
try:
|
||||
huggingface_hub.list_repo_refs(repo)
|
||||
return True
|
||||
except huggingface_hub.errors.RepositoryNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
HF_HUB_AMD_ORG_ACCESS = all(
|
||||
[has_huggingface_access(model_name) for model_name in MODEL_ACCURACIES]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCase:
|
||||
model_id: str
|
||||
tp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationConfig:
|
||||
model_name: str
|
||||
|
||||
def get_model_args(self, tp_size: int):
|
||||
return {
|
||||
"pretrained": self.model_name,
|
||||
"chat_template_args": {"reasoning_effort": "low"},
|
||||
"enable_thinking": True,
|
||||
"think_end_token": "200008",
|
||||
"tensor_parallel_size": tp_size,
|
||||
"dtype": "auto",
|
||||
"gpu_memory_utilization": 0.95,
|
||||
"trust_remote_code": False,
|
||||
"enable_prefix_caching": False,
|
||||
"enforce_eager": False,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.skipif(
|
||||
not HF_HUB_AMD_ORG_ACCESS,
|
||||
reason="Read access to huggingface.co/amd is required for this test.",
|
||||
)
|
||||
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("model_name, expected_accuracy", MODEL_ACCURACIES.items())
|
||||
def test_gpt_oss_attention_quantization(
|
||||
model_name: str, tp_size: int, expected_accuracy: float
|
||||
):
|
||||
model_args = EvaluationConfig(model_name).get_model_args(tp_size)
|
||||
|
||||
extra_run_kwargs = {
|
||||
"gen_kwargs": {"max_gen_toks": 8000},
|
||||
"apply_chat_template": True,
|
||||
"fewshot_as_multiturn": True,
|
||||
"num_fewshot": 5,
|
||||
}
|
||||
|
||||
lm_eval_out = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
model_args=model_args,
|
||||
tasks="gsm8k_platinum",
|
||||
batch_size="auto",
|
||||
**extra_run_kwargs,
|
||||
)
|
||||
measured_accuracy = float(
|
||||
lm_eval_out["results"]["gsm8k_platinum"]["exact_match,flexible-extract"]
|
||||
)
|
||||
|
||||
rtol = 0.02
|
||||
assert (
|
||||
measured_accuracy - rtol < expected_accuracy
|
||||
and measured_accuracy + rtol > expected_accuracy
|
||||
), f"Expected: {expected_accuracy} | Measured: {measured_accuracy}"
|
||||
@@ -1,80 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Test attention quantization of gpt-oss model.
|
||||
The qkv_proj and o_proj in self_attention can be either quantized or excluded.
|
||||
|
||||
Run `pytest tests/models/quantization/test_gpt_oss_attn_quantization.py`.
|
||||
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
from dataclasses import dataclass
|
||||
|
||||
import huggingface_hub
|
||||
import lm_eval
|
||||
import pytest
|
||||
from packaging import version
|
||||
|
||||
MODEL_NAMES = ["amd/gpt-oss-20b-customized-attention-quantization"]
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")
|
||||
) >= version.parse("0.8.99")
|
||||
|
||||
|
||||
def has_huggingface_access(repo):
|
||||
try:
|
||||
huggingface_hub.list_repo_refs(repo)
|
||||
return True
|
||||
except huggingface_hub.errors.RepositoryNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
HF_HUB_AMD_ORG_ACCESS = all(
|
||||
[has_huggingface_access(model_name) for model_name in MODEL_NAMES]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCase:
|
||||
model_id: str
|
||||
tp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationConfig:
|
||||
model_name: str
|
||||
|
||||
def get_model_args(self) -> str:
|
||||
return (
|
||||
f"pretrained={self.model_name},"
|
||||
"tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=False"
|
||||
)
|
||||
|
||||
|
||||
EXPECTED_ACCURACIES = {"arc_challenge": 0.20}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.skipif(
|
||||
not HF_HUB_AMD_ORG_ACCESS,
|
||||
reason="Read access to huggingface.co/amd is required for this test.",
|
||||
)
|
||||
@pytest.mark.parametrize("model_name", MODEL_NAMES)
|
||||
@pytest.mark.parametrize("task_name, expected_accuracy", EXPECTED_ACCURACIES.items())
|
||||
def test_gpt_oss_attention_quantization(
|
||||
model_name: str, task_name: str, expected_accuracy: float
|
||||
):
|
||||
measured_accuracy = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
model_args=EvaluationConfig(model_name).get_model_args(),
|
||||
tasks=task_name,
|
||||
batch_size="auto",
|
||||
)["results"][task_name]["acc,none"]
|
||||
|
||||
rtol = 0.05
|
||||
assert (
|
||||
measured_accuracy - rtol < expected_accuracy
|
||||
and measured_accuracy + rtol > expected_accuracy
|
||||
), f"Expected: {expected_accuracy} | Measured: {measured_accuracy}"
|
||||
@@ -386,6 +386,10 @@ class FusedMoEQuantConfig:
|
||||
def use_nvfp4_w4a4(self) -> bool:
|
||||
return self.quant_dtype == "nvfp4"
|
||||
|
||||
@property
|
||||
def use_mxfp4_w4a8(self) -> bool:
|
||||
return self._a1.dtype == "fp8" and self._w1.dtype == "mxfp4"
|
||||
|
||||
def config_name(self, dtype: torch.dtype) -> str | None:
|
||||
"""
|
||||
Return a string used to construct the filename that contains the
|
||||
@@ -532,6 +536,8 @@ def fp8_w8a8_moe_quant_config(
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: list[int] | None = None,
|
||||
@@ -549,6 +555,8 @@ def fp8_w8a8_moe_quant_config(
|
||||
g1_alphas=g1_alphas,
|
||||
w2_scale=w2_scale,
|
||||
g2_alphas=g2_alphas,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
a1_scale=a1_scale,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_scale=a2_scale,
|
||||
@@ -564,6 +572,8 @@ def int8_w8a8_moe_quant_config(
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
@@ -575,6 +585,8 @@ def int8_w8a8_moe_quant_config(
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=None,
|
||||
@@ -654,6 +666,26 @@ def mxfp4_mxfp8_moe_quant_config(
|
||||
)
|
||||
|
||||
|
||||
def mxfp4_w4a8_moe_quant_config(
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for fp8 activations and mxfp4 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc("fp8", None, a1_scale, None, None, None),
|
||||
_a2=FusedMoEQuantDesc("fp8", None, a2_scale, None, None, None),
|
||||
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
|
||||
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
|
||||
)
|
||||
|
||||
|
||||
def ocp_mx_moe_quant_config(
|
||||
quant_dtype: str,
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
@@ -691,6 +723,8 @@ def nvfp4_moe_quant_config(
|
||||
a2_gscale: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and nvp4 weights.
|
||||
@@ -699,6 +733,8 @@ def nvfp4_moe_quant_config(
|
||||
"nvfp4",
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
g1_alphas=g1_alphas,
|
||||
|
||||
@@ -38,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
@@ -1583,6 +1582,11 @@ def _get_config_quant_dtype(
|
||||
return "mxfp6_e3m2"
|
||||
elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
|
||||
return "mxfp6_e2m3"
|
||||
elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}:
|
||||
return torch.bfloat16
|
||||
elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}:
|
||||
return torch.float8_e4m3fn
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -1617,17 +1621,10 @@ def fused_experts_impl(
|
||||
if use_int4_w4a16:
|
||||
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
|
||||
elif ocp_mx_scheme is not None:
|
||||
if ocp_mx_scheme in {
|
||||
"w_mxfp4_a_mxfp4",
|
||||
"w_mxfp4_a_mxfp6_e3m2",
|
||||
"w_mxfp4_a_mxfp6_e2m3",
|
||||
}:
|
||||
if ocp_mx_scheme.startswith("w_mxfp4"):
|
||||
# 16bit activation and fp4x2 packed weight
|
||||
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
|
||||
elif ocp_mx_scheme in {
|
||||
"w_mxfp6_e3m2_a_mxfp6_e3m2",
|
||||
"w_mxfp6_e2m3_a_mxfp6_e2m3",
|
||||
}:
|
||||
elif ocp_mx_scheme.startswith("w_mxfp6"):
|
||||
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
|
||||
"hidden size mismatch"
|
||||
)
|
||||
@@ -1717,17 +1714,13 @@ def fused_experts_impl(
|
||||
# TODO: On platforms for which `current_platform.supports_mx()` is True
|
||||
# and for which we have a native OCP mx fused MOE kernel,
|
||||
# this dequantization step should not be done.
|
||||
if ocp_mx_scheme in {
|
||||
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
|
||||
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
|
||||
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
|
||||
}:
|
||||
if ocp_mx_scheme.startswith("w_mxfp4"):
|
||||
# Weight has to be dequantized for mxfp4 emulation.
|
||||
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
|
||||
w1_scale = None
|
||||
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
|
||||
w2_scale = None
|
||||
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
|
||||
elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
|
||||
w1 = dequant_mxfp6(
|
||||
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
|
||||
)
|
||||
@@ -1736,7 +1729,7 @@ def fused_experts_impl(
|
||||
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
|
||||
)
|
||||
w2_scale = None
|
||||
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
|
||||
elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
|
||||
w1 = dequant_mxfp6(
|
||||
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
|
||||
)
|
||||
@@ -1779,6 +1772,7 @@ def fused_experts_impl(
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
)
|
||||
|
||||
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
|
||||
@@ -1846,6 +1840,7 @@ def fused_experts_impl(
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
)
|
||||
|
||||
if expert_map is not None:
|
||||
|
||||
@@ -221,12 +221,14 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
|
||||
)
|
||||
|
||||
|
||||
# TODO(rob): move this down to the kernel.
|
||||
def maybe_roundup_hidden_size(
|
||||
hidden_size: int,
|
||||
act_dtype: torch.dtype,
|
||||
quant_config: QuantizationConfig | None,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
is_lora_enabled: bool,
|
||||
model_type: str | None,
|
||||
is_mxfp4_quant: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Given layer hidden size and MoE configurations, round up hidden_size
|
||||
@@ -235,11 +237,12 @@ def maybe_roundup_hidden_size(
|
||||
Args:
|
||||
hidden_size: Layer hidden-size
|
||||
act_dtype: Data type of the layer activations.
|
||||
quant_config: Fused MoE quantization configuration.
|
||||
moe_parallel_config: Fused MoE parallelization strategy configuration.
|
||||
is_lora_enabled: True if the engine is enabled with LoRA. This
|
||||
is used in the case of mxfp4 quantization in selecting the
|
||||
MxFP4Backend.
|
||||
model_type: for checking if gpt-oss
|
||||
is_mxfp4_quant: whether the layer is quantized with mxfp4
|
||||
|
||||
Return:
|
||||
Rounded up hidden_size if rounding up is required based on the configs.
|
||||
@@ -254,7 +257,7 @@ def maybe_roundup_hidden_size(
|
||||
)
|
||||
|
||||
# we are padding globally so EP buffer allocation works
|
||||
if quant_config and quant_config.get_name() == "mxfp4":
|
||||
if model_type == "gpt_oss" and is_mxfp4_quant:
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Backend,
|
||||
get_mxfp4_backend,
|
||||
@@ -398,15 +401,6 @@ class FusedMoE(CustomOp):
|
||||
# Expert mapping used in self.load_weights
|
||||
self.expert_mapping = expert_mapping
|
||||
|
||||
# Round up hidden size if needed.
|
||||
hidden_size = maybe_roundup_hidden_size(
|
||||
hidden_size,
|
||||
moe_in_dtype,
|
||||
quant_config,
|
||||
self.moe_parallel_config,
|
||||
is_lora_enabled=self.vllm_config.lora_config is not None,
|
||||
)
|
||||
|
||||
# For smuggling this layer into the fused moe custom op
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
@@ -508,7 +502,6 @@ class FusedMoE(CustomOp):
|
||||
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
|
||||
|
||||
assert intermediate_size % self.tp_size == 0
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
self.reduce_results = reduce_results
|
||||
self.renormalize = renormalize
|
||||
@@ -548,6 +541,24 @@ class FusedMoE(CustomOp):
|
||||
)
|
||||
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
|
||||
|
||||
# Round up hidden size before creating moe_config.
|
||||
# This way moe_config is created with the correct hidden_size from the start.
|
||||
hidden_size = maybe_roundup_hidden_size(
|
||||
hidden_size=hidden_size,
|
||||
act_dtype=moe_in_dtype,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
is_lora_enabled=vllm_config.lora_config is not None,
|
||||
model_type=(
|
||||
self.vllm_config.model_config.hf_config.model_type
|
||||
if self.vllm_config.model_config is not None
|
||||
else None
|
||||
),
|
||||
is_mxfp4_quant=(
|
||||
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
|
||||
),
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.moe_config: FusedMoEConfig = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
|
||||
@@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
mxfp8_e4m3_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
per_tensor_dequantize,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
@@ -241,7 +244,27 @@ def moe_kernel_quantize_input(
|
||||
per_act_token_quant: bool,
|
||||
block_shape: list[int] | None = None,
|
||||
is_fp4_scale_swizzled: bool = True,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation
|
||||
if ocp_mx_scheme is not None:
|
||||
if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}:
|
||||
pass # No QDQ needed for these schemes
|
||||
elif ocp_mx_scheme.endswith("a_fp8"):
|
||||
# Perform QDQ (quantize and dequantize) on activation for emulation
|
||||
# purpose, because there is no native kernel for weight in ocp_mx_scheme
|
||||
# and activation in FP8. The implementation is based on existing
|
||||
# non-emulation ops.
|
||||
qA, qA_scale = ops.scaled_fp8_quant(
|
||||
A, A_scale, use_per_token_if_dynamic=False
|
||||
)
|
||||
A = per_tensor_dequantize(qA, qA_scale).to(A.dtype)
|
||||
# After QDQ, we don't need further quantization
|
||||
return A, None
|
||||
# else: For other schemes (e.g., *_a_mxfp6_e3m2, *_a_mxfp6_e2m3),
|
||||
# weights are already dequantized, and we proceed with normal
|
||||
# activation quantization below.
|
||||
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == torch.int8:
|
||||
|
||||
@@ -168,3 +168,19 @@ class QuantizationConfig(ABC):
|
||||
Interface to update values after config initialization.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
|
||||
"""
|
||||
Determine if mxfp4 quantization will be used for this config.
|
||||
|
||||
This allows hidden_size rounding to happen before moe_config creation
|
||||
without needing to instantiate quant_method first.
|
||||
|
||||
Args:
|
||||
prefix: The layer prefix/name in the model
|
||||
layer: The layer module
|
||||
|
||||
Returns:
|
||||
True if this config uses MXFP4 quantization, False otherwise
|
||||
"""
|
||||
return False
|
||||
|
||||
@@ -229,10 +229,15 @@ class Mxfp4Config(QuantizationConfig):
|
||||
)
|
||||
return None
|
||||
|
||||
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
|
||||
"""MXFP4 config always uses MXFP4 quantization."""
|
||||
return True
|
||||
|
||||
|
||||
class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.weight_dtype = "mxfp4"
|
||||
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
|
||||
|
||||
self.marlin_input_dtype = None
|
||||
|
||||
@@ -320,38 +320,45 @@ class QuarkConfig(QuantizationConfig):
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
|
||||
|
||||
def _is_ocp_mx(
|
||||
self,
|
||||
weight_quant: dict[str, Any] | None,
|
||||
input_quant: dict[str, Any] | None,
|
||||
def _is_w_ocp_mx_a_x(
|
||||
self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
|
||||
) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
"""
|
||||
This check returns True only if it is an OCP-MX weight quantization.
|
||||
The activation can be any data type (e.g., FP16/BF16, FP8, or OCP-MX format).
|
||||
The rationale for checking only the weight type is that
|
||||
the model loading concept and process primarily concerns the weights themselves.
|
||||
"""
|
||||
# Confirm weights quantized.
|
||||
if weight_quant is None:
|
||||
logger.debug(
|
||||
"Quark model is not in OCP MX format: "
|
||||
"weight_quant or input_quant not set"
|
||||
"Quark model's weight quantization is incompatible with OCP_MX format: "
|
||||
"weight_quant is not set."
|
||||
)
|
||||
return False
|
||||
|
||||
# Input and weight qscheme needs to be per group.
|
||||
if (
|
||||
weight_quant.get("qscheme") != "per_group"
|
||||
or input_quant.get("qscheme") != "per_group"
|
||||
):
|
||||
logger.debug("Quark model is not in OCP MX format: not per_group")
|
||||
if weight_quant.get("qscheme") != "per_group":
|
||||
logger.debug(
|
||||
"Quark model's weight quantization is incompatible with OCP MX format: "
|
||||
"weight is not per_group."
|
||||
)
|
||||
return False
|
||||
|
||||
# Input and weight group size needs to be 32.
|
||||
if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
|
||||
logger.debug("Quark model is not in OCP MX format: not group_size=32")
|
||||
if weight_quant.get("group_size") != 32:
|
||||
logger.debug(
|
||||
"Quark model's weight quantization is incompatible with OCP MX format: "
|
||||
"group_size of weight is not 32."
|
||||
)
|
||||
return False
|
||||
|
||||
# Activations and weight scales need to be in e8m0 format.
|
||||
if (
|
||||
weight_quant.get("scale_format") != "e8m0"
|
||||
or input_quant.get("scale_format") != "e8m0"
|
||||
):
|
||||
logger.debug("Quark model is not in OCP MX format: not scale_format e8m0")
|
||||
if weight_quant.get("scale_format") != "e8m0":
|
||||
logger.debug(
|
||||
"Quark model's weight quantization is incompatible with OCP MX format: "
|
||||
"scale_format of weight is not e8m0."
|
||||
)
|
||||
return False
|
||||
|
||||
# Input and weight dtypes need to be any of fp4,
|
||||
@@ -360,14 +367,31 @@ class QuarkConfig(QuantizationConfig):
|
||||
"fp4",
|
||||
"fp6_e3m2",
|
||||
"fp6_e2m3",
|
||||
} or input_quant.get("dtype") not in {"fp4", "fp6_e3m2", "fp6_e2m3"}:
|
||||
}:
|
||||
logger.debug(
|
||||
"Quark model is not in OCP MX format: dtype not fp4, fp6_e3m2, fp6_e2m3"
|
||||
"Quark model's weight quantization is incompatible with OCP MX format: "
|
||||
"dtype is not in {fp4, fp6_e3m2, fp6_e2m3}."
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
|
||||
"""
|
||||
For Quark, determine if it's OCP MXFP4 by checking config directly.
|
||||
This allows hidden_size rounding to happen before moe_config creation.
|
||||
"""
|
||||
layer_quant_config = self._find_matched_config(prefix, layer)
|
||||
weight_config = layer_quant_config.get("weight")
|
||||
input_config = layer_quant_config.get("input_tensors")
|
||||
|
||||
return (
|
||||
self._is_w_ocp_mx_a_x(weight_config, input_config)
|
||||
and weight_config is not None
|
||||
and weight_config.get("dtype") == "fp4"
|
||||
and getattr(torch, "float4_e2m1fn_x2", None) is not None
|
||||
)
|
||||
|
||||
def _find_matched_config(
|
||||
self, layer_name: str, module: torch.nn.Module
|
||||
) -> dict[str, Any]:
|
||||
@@ -441,7 +465,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=input_config.get("symmetric"),
|
||||
)
|
||||
elif self._is_ocp_mx(weight_config, input_config):
|
||||
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
return QuarkOCP_MX(weight_config, input_config)
|
||||
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
@@ -18,9 +19,15 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
mxfp4_w4a8_moe_quant_config,
|
||||
mxfp4_w4a16_moe_quant_config,
|
||||
ocp_mx_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Backend,
|
||||
get_mxfp4_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_fp8_moe_layer_for_marlin,
|
||||
)
|
||||
@@ -37,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -46,6 +54,7 @@ __all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
|
||||
class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.has_bias = self.moe.has_bias
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
@@ -67,7 +76,7 @@ class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_ocp_mx(weight_config, input_config):
|
||||
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
|
||||
else:
|
||||
raise RuntimeError("Unsupported FusedMoe scheme")
|
||||
@@ -86,6 +95,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
|
||||
self.weight_qscheme = self.weight_quant.get("qscheme")
|
||||
self.input_qscheme = self.input_quant.get("qscheme")
|
||||
self.weight_dtype = self.weight_quant.get("dtype", "").replace(
|
||||
"fp8_e4m3", "fp8"
|
||||
)
|
||||
self.input_dtype = self.input_quant.get("dtype", "").replace("fp8_e4m3", "fp8")
|
||||
per_tensor = (
|
||||
self.weight_qscheme == "per_tensor" and self.input_qscheme == "per_tensor"
|
||||
)
|
||||
@@ -121,6 +134,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
self.model_type = getattr(
|
||||
get_current_vllm_config().model_config.hf_config, "model_type", None
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -166,9 +183,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
if self.weight_qscheme == "per_tensor":
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They are combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
if self.model_type != "gpt_oss":
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
else:
|
||||
# For gpt_oss, the w1(gate) & w3(up) are fused as one.
|
||||
# Therefore, only one weight scale for each expert.
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 1, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
@@ -220,6 +244,27 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
if self.has_bias:
|
||||
w13_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_bias", w13_bias)
|
||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||
|
||||
w2_bias = torch.nn.Parameter(
|
||||
torch.zeros(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_bias", w2_bias)
|
||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||
else:
|
||||
layer.w13_bias, layer.w2_bias = None, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Fp8 moe kernels require a single activation scale.
|
||||
# We take the max of all the scales in case they differ.
|
||||
@@ -278,21 +323,40 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
|
||||
# For gpt_oss, w1 and w3 are fused into a single combined
|
||||
# gate_up_proj tensor with size 2*intermediate_size_per_partition
|
||||
# and only one scale per expert.
|
||||
# Process the entire weight tensor as one shard.
|
||||
if self.model_type == "gpt_oss":
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
# Process all 2*intermediate_size_per_partition rows at once
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id],
|
||||
layer.w13_weight[expert_id],
|
||||
layer.w13_weight_scale[expert_id][0],
|
||||
)
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
|
||||
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
layer.w13_weight[expert_id], _ = ops.scaled_fp8_quant(
|
||||
dq_weight, max_w13_scales[expert_id]
|
||||
)
|
||||
start += shard_size
|
||||
else:
|
||||
# For non-gpt_oss, process w1 and w3 shards separately
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id],
|
||||
)
|
||||
(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
_,
|
||||
) = ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
max_w13_scales, requires_grad=False
|
||||
)
|
||||
|
||||
# quark's scale is 1 dim.
|
||||
elif self.weight_qscheme == "per_channel":
|
||||
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
|
||||
@@ -343,6 +407,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
per_act_token_quant=self.input_qscheme == "per_channel",
|
||||
per_out_ch_quant=self.weight_qscheme == "per_channel",
|
||||
)
|
||||
@@ -563,7 +629,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
def __init__(
|
||||
self,
|
||||
weight_config: dict[str, Any],
|
||||
input_config: dict[str, Any],
|
||||
input_config: dict[str, Any] | None,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
@@ -571,35 +637,79 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
self.input_quant = input_config
|
||||
|
||||
weight_qscheme = self.weight_quant.get("qscheme")
|
||||
input_qscheme = self.input_quant.get("qscheme")
|
||||
if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
|
||||
if not weight_qscheme == "per_group":
|
||||
raise ValueError(
|
||||
"For MX(FP4) Fused MoE layers, only per-group scales "
|
||||
"for weights and activations are supported. Found "
|
||||
f"{weight_qscheme}, {input_qscheme}"
|
||||
f"for weights are supported. Found {weight_qscheme}."
|
||||
) # noqa E501
|
||||
|
||||
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||
|
||||
self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
|
||||
self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
|
||||
if self.input_quant is not None:
|
||||
input_quant = self.input_quant["dtype"]
|
||||
if input_quant in ["fp4", "fp6_e3m2", "fp6_e2m3"]:
|
||||
self.input_dtype = input_quant.replace("fp", "mxfp")
|
||||
elif input_quant == "fp8_e4m3":
|
||||
self.input_dtype = input_quant.replace("fp8_e4m3", "fp8")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Current input dtype {input_quant} is not compatible \
|
||||
with OCP MX (weight) MoE quantization. Please open an issue"
|
||||
)
|
||||
else:
|
||||
self.input_dtype = None
|
||||
|
||||
self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
|
||||
|
||||
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
|
||||
self.input_dtype, self.weight_dtype
|
||||
)
|
||||
|
||||
if self.static_input_scales:
|
||||
if self.ocp_mx_scheme is None:
|
||||
raise ValueError(
|
||||
f"Unsupported OCP MX dtype combination for MoE: "
|
||||
f"input_dtype={self.input_dtype}, weight_dtype={self.weight_dtype}. "
|
||||
f"Please check that the combination is supported in OCP_MX_Scheme."
|
||||
)
|
||||
|
||||
self.mxfp4_backend: Mxfp4Backend | None = None
|
||||
if self.ocp_mx_scheme == "w_mxfp4":
|
||||
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
|
||||
|
||||
if self.input_quant is not None:
|
||||
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||
else:
|
||||
self.static_input_scales = False
|
||||
|
||||
if any(
|
||||
self.ocp_mx_scheme.endswith(a_scheme)
|
||||
for a_scheme in ["a_mxfp4", "a_mxfp6_e3m2", "a_mxfp6_e2m3"]
|
||||
):
|
||||
if self.static_input_scales:
|
||||
raise NotImplementedError(
|
||||
"QuarkOCP_MX_MoEMethod with static input scales is currently "
|
||||
f"not implemented for OCP MX scheme {self.ocp_mx_scheme}. "
|
||||
"Please open an issue."
|
||||
)
|
||||
elif self.ocp_mx_scheme.endswith("a_fp8") and not self.static_input_scales:
|
||||
raise NotImplementedError(
|
||||
"QuarkOCP_MX_MoEMethod with static input scales is currently "
|
||||
"not implemented. Please open an issue."
|
||||
"QuarkOCP_MX_MoEMethod with dynamic input scales is currently "
|
||||
f"not implemented for OCP MX scheme {self.ocp_mx_scheme}. "
|
||||
"Please open an issue."
|
||||
)
|
||||
|
||||
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
self.emulate = not current_platform.supports_mx() or not (
|
||||
self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
|
||||
self.model_type = getattr(
|
||||
get_current_vllm_config().model_config.hf_config, "model_type", None
|
||||
)
|
||||
|
||||
self._emulate = (
|
||||
not current_platform.supports_mx()
|
||||
or not self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
|
||||
|
||||
self.emulate = True if self.model_type == "gpt_oss" else self._emulate
|
||||
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
@@ -640,12 +750,23 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
)
|
||||
|
||||
params_dtype = torch.uint8
|
||||
if self.model_type == "gpt_oss":
|
||||
if current_platform.is_rocm():
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 256
|
||||
)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 64
|
||||
)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
self.get_packed_dim(hidden_size, self.weight_dtype),
|
||||
dtype=params_dtype,
|
||||
),
|
||||
@@ -659,7 +780,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
self.get_packed_dim(intermediate_size_per_partition, self.weight_dtype),
|
||||
self.get_packed_dim(
|
||||
intermediate_size_per_partition_after_pad, self.weight_dtype
|
||||
),
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -672,7 +795,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // OCP_MX_BLOCK_SIZE,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
@@ -682,7 +805,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
torch.ones(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
intermediate_size_per_partition_after_pad // OCP_MX_BLOCK_SIZE,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -693,8 +816,96 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
if self.has_bias:
|
||||
w13_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_bias", w13_bias)
|
||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||
|
||||
w2_bias = torch.nn.Parameter(
|
||||
torch.zeros(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_bias", w2_bias)
|
||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||
else:
|
||||
layer.w13_bias, layer.w2_bias = None, None
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.static_input_scales:
|
||||
w13_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
else:
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.static_input_scales:
|
||||
# firstly, process activations if fp8 static input
|
||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None."
|
||||
)
|
||||
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
|
||||
layer.w2_input_scale
|
||||
):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer. "
|
||||
)
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
# Normalize the weights and scales
|
||||
_, _, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
torch.empty_like(layer.w13_weight, dtype=torch.float8_e4m3fnuz),
|
||||
torch.empty_like(
|
||||
layer.w13_weight_scale, dtype=layer.w13_weight_scale.dtype
|
||||
),
|
||||
layer.w13_input_scale,
|
||||
)
|
||||
_, _, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
torch.empty_like(layer.w2_weight, dtype=torch.float8_e4m3fnuz),
|
||||
torch.empty_like(
|
||||
layer.w2_weight_scale, dtype=layer.w13_weight_scale.dtype
|
||||
),
|
||||
layer.w2_input_scale,
|
||||
)
|
||||
# Reset the parameter
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
w13_input_scale, requires_grad=False
|
||||
)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
w2_input_scale, requires_grad=False
|
||||
)
|
||||
|
||||
# secondly, process mxfp weights
|
||||
if self.emulate:
|
||||
torch.cuda.empty_cache()
|
||||
return
|
||||
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
@@ -725,15 +936,40 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return ocp_mx_moe_quant_config(
|
||||
quant_dtype=self.input_dtype,
|
||||
weight_dtype=self.weight_dtype,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
)
|
||||
if self.ocp_mx_scheme == "w_mxfp4":
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
)
|
||||
elif self.ocp_mx_scheme == "w_mxfp4_a_fp8":
|
||||
return mxfp4_w4a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
block_shape=None,
|
||||
)
|
||||
elif self.ocp_mx_scheme in ["w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"]:
|
||||
raise NotImplementedError(
|
||||
"Currently there is no corresponding fused moe quant config configured "
|
||||
f"in vLLM for OCP MX scheme {self.ocp_mx_scheme}. Please open an issue."
|
||||
)
|
||||
else:
|
||||
return ocp_mx_moe_quant_config(
|
||||
quant_dtype=self.input_dtype,
|
||||
weight_dtype=self.weight_dtype,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -743,24 +979,34 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if not self.emulate:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
if (
|
||||
self.model_type == "gpt_oss"
|
||||
and self.mxfp4_backend == Mxfp4Backend.TRITON
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Triton kernel implemented fused MoE for GPT_OSS model "
|
||||
"in Quark(MoE) format is not integrated or provided yet."
|
||||
)
|
||||
|
||||
out = rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
out = fused_experts(
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -773,5 +1019,3 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
@@ -20,26 +20,44 @@ SUPPORTED_OCP_MX_DTYPES = {"mxfp4", "mxfp6_e3m2", "mxfp6_e2m3"}
|
||||
|
||||
|
||||
class OCP_MX_Scheme(str, Enum):
|
||||
w_mxfp4 = "w_mxfp4"
|
||||
w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4"
|
||||
w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2"
|
||||
w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3"
|
||||
w_mxfp4_a_fp8 = "w_mxfp4_a_fp8"
|
||||
w_mxfp6_e3m2 = "w_mxfp6_e3m2"
|
||||
w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2"
|
||||
w_mxfp6_e3m2_a_fp8 = "w_mxfp6_e3m2_a_fp8"
|
||||
w_mxfp6_e2m3 = "w_mxfp6_e2m3"
|
||||
w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3"
|
||||
w_mxfp6_e2m3_a_fp8 = "w_mxfp6_e2m3_a_fp8"
|
||||
|
||||
@classmethod
|
||||
def from_quant_dtype(cls, input_dtype: str | None, weight_dtype: str | None):
|
||||
if input_dtype not in OCP_MX_DTYPES or weight_dtype not in OCP_MX_DTYPES:
|
||||
if input_dtype not in OCP_MX_DTYPES and weight_dtype not in OCP_MX_DTYPES:
|
||||
return None
|
||||
elif input_dtype is None and weight_dtype == "mxfp4":
|
||||
return cls.w_mxfp4
|
||||
elif input_dtype is None and weight_dtype == "mxfp6_e3m2":
|
||||
return cls.w_mxfp6_e3m2
|
||||
elif input_dtype is None and weight_dtype == "mxfp6_e2m3":
|
||||
return cls.w_mxfp6_e2m3
|
||||
elif input_dtype == "mxfp4" and weight_dtype == "mxfp4":
|
||||
return cls.w_mxfp4_a_mxfp4
|
||||
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp4":
|
||||
return cls.w_mxfp4_a_mxfp6_e3m2
|
||||
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp4":
|
||||
return cls.w_mxfp4_a_mxfp6_e2m3
|
||||
elif input_dtype == "fp8" and weight_dtype == "mxfp4":
|
||||
return cls.w_mxfp4_a_fp8
|
||||
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp6_e3m2":
|
||||
return cls.w_mxfp6_e3m2_a_mxfp6_e3m2
|
||||
elif input_dtype == "fp8" and weight_dtype == "mxfp6_e3m2":
|
||||
return cls.w_mxfp6_e3m2_a_fp8
|
||||
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp6_e2m3":
|
||||
return cls.w_mxfp6_e2m3_a_mxfp6_e2m3
|
||||
elif input_dtype == "fp8" and weight_dtype == "mxfp6_e2m3":
|
||||
return cls.w_mxfp6_e2m3_a_fp8
|
||||
else:
|
||||
logger.warning(
|
||||
"input_dtype='%s' and"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -25,13 +26,17 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -98,6 +103,7 @@ class OAIAttention(nn.Module):
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.num_attention_heads,
|
||||
total_num_kv_heads=self.num_key_value_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
@@ -105,6 +111,7 @@ class OAIAttention(nn.Module):
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.num_attention_heads * self.head_dim,
|
||||
output_size=self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
@@ -306,6 +313,19 @@ class GptOssModel(nn.Module):
|
||||
return x, aux_hidden_states
|
||||
return x
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, weight scales, activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
# NOTE: this is only used for quark.
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
self,
|
||||
ckpt_gate_proj_name="w1",
|
||||
ckpt_down_proj_name="w2",
|
||||
ckpt_up_proj_name="w3",
|
||||
num_experts=self.config.num_local_experts,
|
||||
num_redundant_experts=0,
|
||||
)
|
||||
|
||||
def _load_weights_mxfp4(
|
||||
self,
|
||||
ep_rank_end: int,
|
||||
@@ -318,7 +338,6 @@ class GptOssModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
mxfp4_block = 32
|
||||
use_ep = self.parallel_config.enable_expert_parallel
|
||||
num_experts = self.config.num_local_experts
|
||||
|
||||
@@ -333,9 +352,11 @@ class GptOssModel(nn.Module):
|
||||
)
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
intermediate_size_block = intermediate_size // mxfp4_block
|
||||
intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
|
||||
per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
|
||||
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
|
||||
per_rank_intermediate_size = (
|
||||
per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
|
||||
)
|
||||
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
@@ -370,7 +391,9 @@ class GptOssModel(nn.Module):
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[
|
||||
..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
|
||||
...,
|
||||
tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end
|
||||
// OCP_MX_BLOCK_SIZE,
|
||||
]
|
||||
|
||||
param = params_dict[name]
|
||||
@@ -495,6 +518,449 @@ class GptOssModel(nn.Module):
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _load_weights_quark(
|
||||
self,
|
||||
ep_rank_end: int,
|
||||
ep_rank_start: int,
|
||||
heads_per_rank: int,
|
||||
head_start: int,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
stacked_params_mapping: list[tuple[str, ...]],
|
||||
) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
use_ep = self.parallel_config.enable_expert_parallel
|
||||
num_experts = self.config.num_local_experts
|
||||
|
||||
if use_ep:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
|
||||
tp_size=get_tensor_model_parallel_world_size(),
|
||||
dp_size=get_dp_group().world_size,
|
||||
dp_rank=get_dp_group().rank_in_group,
|
||||
pcp_size=get_pcp_group().world_size,
|
||||
pcp_rank=get_pcp_group().rank_in_group,
|
||||
)
|
||||
|
||||
def _get_moe_weight_dtype(layer_id: int = 0) -> str | None:
|
||||
"""Helper function to get MoE quantization weight dtype.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index to check (default 0, as all layers should
|
||||
have the same quantization method)
|
||||
|
||||
Returns:
|
||||
Weight dtype string (e.g., "mxfp4", "fp8") or None if not available
|
||||
"""
|
||||
if hasattr(self.layers[layer_id].mlp.experts.quant_method, "weight_dtype"):
|
||||
return self.layers[layer_id].mlp.experts.quant_method.weight_dtype
|
||||
return None
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
|
||||
moe_weight_dtype = _get_moe_weight_dtype(layer_id=0)
|
||||
|
||||
if moe_weight_dtype == "mxfp4":
|
||||
# MXFP4 requires OCP_MX_BLOCK_SIZE alignment
|
||||
intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
|
||||
per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
|
||||
per_rank_intermediate_size = (
|
||||
per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
|
||||
)
|
||||
else:
|
||||
# FP8 and other formats don't need alignment
|
||||
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
|
||||
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
layer_id, expert_id, fused_name = None, None, None
|
||||
moe_quant_method = None
|
||||
if "experts" in name:
|
||||
parts = name.split(".")
|
||||
ids = [s for s in parts if s.isdigit()]
|
||||
|
||||
# for amd-quark format that each expert is seperated
|
||||
# need to extract the parameter name with experts fused.
|
||||
# example model: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
|
||||
if len(ids) == 2:
|
||||
layer_id, expert_id = int(ids[0]), int(ids[-1])
|
||||
parts.pop(len(parts) - 1 - parts[::-1].index(str(expert_id)))
|
||||
fused_name = ".".join(parts)
|
||||
|
||||
# for openai mxfp4 format that all experts are combined
|
||||
# no need to extract the parameter name with experts fused.
|
||||
# models: openai/gpt-oss-20b, openai/gpt-oss-120b
|
||||
elif len(ids) == 1:
|
||||
layer_id, expert_id = int(ids[0]), None
|
||||
fused_name = name
|
||||
|
||||
else:
|
||||
raise NameError(
|
||||
f"Layer {name} contains more than 2 numeric indices. This is "
|
||||
"an unexpected condition. Please open an issue if encountered."
|
||||
)
|
||||
|
||||
moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id)
|
||||
|
||||
def kv_cache_scale_loader(
|
||||
quant_config: QuantizationConfig,
|
||||
name: str,
|
||||
params_dict: dict[str, typing.Any],
|
||||
weight: torch.Tensor,
|
||||
default_weight_loader: Callable[..., None],
|
||||
loaded_params: set[str],
|
||||
) -> tuple[bool, set[str]]:
|
||||
"""
|
||||
Load KV cache output scales.
|
||||
Returns:
|
||||
Tuple of (bool, set):
|
||||
- bool: True if KV-cache scale was loaded into loaded_params
|
||||
- set: Updated set of loaded_params if True else the original set
|
||||
"""
|
||||
# load explicit cached KV output scale from quant_config
|
||||
if quant_config is not None and (
|
||||
scale_name := quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
if weight.numel() != 1:
|
||||
raise ValueError(
|
||||
f"KV cache scale '{scale_name}' is expected to be a "
|
||||
f"scalar, but got a tensor of shape {weight.shape}."
|
||||
)
|
||||
# Ensure weight is a scalar before passing to loader.
|
||||
weight_loader(param, weight.flatten()[0])
|
||||
loaded_params.add(scale_name)
|
||||
return True, loaded_params
|
||||
|
||||
return False, loaded_params
|
||||
|
||||
load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader(
|
||||
self.quant_config,
|
||||
name,
|
||||
params_dict,
|
||||
loaded_weight,
|
||||
default_weight_loader,
|
||||
loaded_params,
|
||||
)
|
||||
if load_kv_cache_scale_completed:
|
||||
continue
|
||||
|
||||
if (
|
||||
all(key in name for key in ["input_scale", "mlp.experts"])
|
||||
and expert_id is not None
|
||||
):
|
||||
assert loaded_weight.numel() == 1
|
||||
expert_data = params_dict[fused_name].data[expert_id]
|
||||
expert_data.copy_(loaded_weight)
|
||||
loaded_params.add(fused_name)
|
||||
continue
|
||||
|
||||
# Unified handler for mxfp4 weights and scales
|
||||
elif moe_quant_method == "mxfp4" and any(
|
||||
name.endswith(suffix)
|
||||
for suffix in [
|
||||
".w13_weight_scale",
|
||||
".w2_weight_scale",
|
||||
".w13_weight",
|
||||
".w2_weight",
|
||||
]
|
||||
):
|
||||
is_w13 = ".w13_" in name
|
||||
is_scale = "_scale" in name
|
||||
|
||||
# Reshape weight for mxfp4 if needed (not for scales)
|
||||
if not is_scale and expert_id is None:
|
||||
if is_w13:
|
||||
if loaded_weight.dim() < 3:
|
||||
raise ValueError(
|
||||
f"Expected w13_weight to have at least 3 "
|
||||
f"dimensions, got shape "
|
||||
f"{loaded_weight.shape}"
|
||||
)
|
||||
if loaded_weight.shape[0] != num_experts:
|
||||
raise ValueError(
|
||||
f"Expected w13_weight first dimension to be "
|
||||
f"{num_experts}, got "
|
||||
f"{loaded_weight.shape[0]}"
|
||||
)
|
||||
loaded_weight = loaded_weight.view(
|
||||
num_experts, 2 * intermediate_size, -1
|
||||
).contiguous()
|
||||
else:
|
||||
if loaded_weight.dim() < 3:
|
||||
raise ValueError(
|
||||
f"Expected w2_weight to have at least 3 "
|
||||
f"dimensions, got shape "
|
||||
f"{loaded_weight.shape}"
|
||||
)
|
||||
if loaded_weight.shape[0] != num_experts:
|
||||
raise ValueError(
|
||||
f"Expected w2_weight first dimension to be "
|
||||
f"{num_experts}, got "
|
||||
f"{loaded_weight.shape[0]}"
|
||||
)
|
||||
loaded_weight = loaded_weight.view(
|
||||
num_experts, -1, intermediate_size // 2
|
||||
).contiguous()
|
||||
|
||||
if use_ep:
|
||||
sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
if is_w13:
|
||||
if expert_id is None:
|
||||
sliced_weight = loaded_weight[
|
||||
:, 2 * tp_rank_start : 2 * tp_rank_end, ...
|
||||
]
|
||||
else:
|
||||
sliced_weight = loaded_weight[
|
||||
2 * tp_rank_start : 2 * tp_rank_end, ...
|
||||
]
|
||||
else:
|
||||
if is_scale:
|
||||
sliced_weight = loaded_weight[
|
||||
...,
|
||||
tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end
|
||||
// OCP_MX_BLOCK_SIZE,
|
||||
]
|
||||
else:
|
||||
sliced_weight = loaded_weight[
|
||||
..., tp_rank_start // 2 : tp_rank_end // 2
|
||||
]
|
||||
|
||||
# NOTE(rob): because gpt-oss ckpt has "unique" structure with
|
||||
# fused gate_up_proj fused on disk, we cannot use the existing
|
||||
# weight loaders without added complexity, so just do the
|
||||
# direct load here.
|
||||
param = params_dict[fused_name]
|
||||
expert_data = param.data[expert_id]
|
||||
dim1 = sliced_weight.shape[0]
|
||||
dim2 = sliced_weight.shape[1]
|
||||
expert_data.data[:dim1, :dim2].copy_(sliced_weight)
|
||||
loaded_params.add(fused_name)
|
||||
continue
|
||||
|
||||
elif name.endswith(".w13_weight") and moe_quant_method == "fp8":
|
||||
if use_ep:
|
||||
narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
if expert_id is None:
|
||||
narrow_weight = loaded_weight[
|
||||
:, 2 * tp_rank_start : 2 * tp_rank_end, :
|
||||
]
|
||||
else:
|
||||
narrow_weight = loaded_weight[
|
||||
2 * tp_rank_start : 2 * tp_rank_end, :
|
||||
]
|
||||
|
||||
assert fused_name is not None
|
||||
param = params_dict[fused_name]
|
||||
|
||||
if expert_id is None:
|
||||
param.data.copy_(narrow_weight)
|
||||
else:
|
||||
param.data[expert_id].copy_(narrow_weight)
|
||||
|
||||
loaded_params.add(fused_name)
|
||||
continue
|
||||
|
||||
elif name.endswith(".w13_weight_scale") and moe_quant_method == "fp8":
|
||||
assert fused_name is not None
|
||||
param = params_dict[fused_name]
|
||||
|
||||
# Check if this is per-channel or per-tensor scale
|
||||
if loaded_weight.numel() > 1 and loaded_weight.dim() == 1:
|
||||
if use_ep:
|
||||
narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = loaded_weight[
|
||||
2 * tp_rank_start : 2 * tp_rank_end
|
||||
]
|
||||
else:
|
||||
narrow_weight = loaded_weight
|
||||
|
||||
if expert_id is None:
|
||||
param.data.copy_(narrow_weight)
|
||||
else:
|
||||
param.data[expert_id].copy_(narrow_weight)
|
||||
|
||||
loaded_params.add(fused_name)
|
||||
continue
|
||||
|
||||
elif name.endswith(".w13_input_scale") and moe_quant_method == "fp8":
|
||||
assert fused_name is not None
|
||||
param = params_dict[fused_name]
|
||||
|
||||
if expert_id is None:
|
||||
param.data.copy_(loaded_weight)
|
||||
else:
|
||||
param.data[expert_id].copy_(loaded_weight)
|
||||
|
||||
loaded_params.add(fused_name)
|
||||
continue
|
||||
|
||||
elif name.endswith(".w2_weight") and moe_quant_method == "fp8":
|
||||
if use_ep:
|
||||
narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
if expert_id is None:
|
||||
narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end]
|
||||
else:
|
||||
narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end]
|
||||
|
||||
assert fused_name is not None
|
||||
param = params_dict[fused_name]
|
||||
|
||||
if expert_id is None:
|
||||
param.data.copy_(narrow_weight)
|
||||
else:
|
||||
param.data[expert_id].copy_(narrow_weight)
|
||||
|
||||
loaded_params.add(fused_name)
|
||||
continue
|
||||
|
||||
elif name.endswith(".w2_weight_scale") and moe_quant_method == "fp8":
|
||||
assert fused_name is not None
|
||||
param = params_dict[fused_name]
|
||||
|
||||
if use_ep:
|
||||
narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = loaded_weight
|
||||
|
||||
if expert_id is None:
|
||||
param.data.copy_(narrow_weight)
|
||||
else:
|
||||
param.data[expert_id].copy_(narrow_weight)
|
||||
|
||||
loaded_params.add(fused_name)
|
||||
continue
|
||||
|
||||
# Unified handler for bias loading (w13_bias and w2_bias)
|
||||
elif name.endswith(".w13_bias") or name.endswith(".w2_bias"):
|
||||
is_w13_bias = name.endswith(".w13_bias")
|
||||
|
||||
if use_ep:
|
||||
sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
if is_w13_bias:
|
||||
if expert_id is None:
|
||||
sliced_weight = loaded_weight[
|
||||
:, 2 * tp_rank_start : 2 * tp_rank_end
|
||||
]
|
||||
else:
|
||||
sliced_weight = loaded_weight[
|
||||
2 * tp_rank_start : 2 * tp_rank_end
|
||||
]
|
||||
else:
|
||||
sliced_weight = loaded_weight
|
||||
if tp_rank != 0:
|
||||
sliced_weight = sliced_weight.zero_()
|
||||
|
||||
# NOTE(rob): because gpt-oss ckpt has "unique" structure with
|
||||
# fused gate_up_proj fused on disk, we cannot use the existing
|
||||
# weight loaders without added complexity, so just do the
|
||||
# direct load here.
|
||||
assert fused_name is not None
|
||||
param = params_dict[fused_name]
|
||||
expert_data = param.data[expert_id]
|
||||
dim1 = sliced_weight.shape[0]
|
||||
expert_data.data[:dim1].copy_(sliced_weight)
|
||||
loaded_params.add(fused_name)
|
||||
continue
|
||||
|
||||
elif "sinks" in name:
|
||||
# Handle attention sinks (distributed across ranks)
|
||||
param = params_dict[name]
|
||||
narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)
|
||||
param.data.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
if name.endswith("scale"):
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
loaded_params.add(name)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
param_name, weight_name, mapping_expert_id, shard_id = mapping
|
||||
weight_name = (
|
||||
weight_name[:-1] if weight_name.endswith(".") else weight_name
|
||||
)
|
||||
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
param = params_dict[fused_name]
|
||||
# We should ask the weight loader to return success or not
|
||||
# here since otherwise we may skip experts with other
|
||||
# available replicas.
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
)
|
||||
# Use checkpoint's expert_id for quark format (when expert_id
|
||||
# is extracted from weight name), otherwise use mapping's expert_id
|
||||
actual_expert_id = (
|
||||
expert_id if expert_id is not None else mapping_expert_id
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
fused_name,
|
||||
shard_id=shard_id,
|
||||
expert_id=actual_expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
name = fused_name
|
||||
loaded_params.add(name)
|
||||
break
|
||||
else:
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _load_weights_other(
|
||||
self,
|
||||
ep_rank_end: int,
|
||||
@@ -635,6 +1101,7 @@ class GptOssModel(nn.Module):
|
||||
if hasattr(self.config, "quantization_config")
|
||||
else None
|
||||
)
|
||||
|
||||
if quant_method == "mxfp4":
|
||||
return self._load_weights_mxfp4(
|
||||
ep_rank_end,
|
||||
@@ -644,6 +1111,15 @@ class GptOssModel(nn.Module):
|
||||
weights,
|
||||
stacked_params_mapping,
|
||||
)
|
||||
elif quant_method == "quark":
|
||||
return self._load_weights_quark(
|
||||
ep_rank_end,
|
||||
ep_rank_start,
|
||||
heads_per_rank,
|
||||
head_start,
|
||||
weights,
|
||||
stacked_params_mapping,
|
||||
)
|
||||
else:
|
||||
return self._load_weights_other(
|
||||
ep_rank_end,
|
||||
@@ -676,6 +1152,15 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
|
||||
# MoE Bias
|
||||
".gate_up_proj_bias": ".w13_bias",
|
||||
".down_proj_bias": ".w2_bias",
|
||||
# For quark format
|
||||
".gate_up_proj.weight": ".w13_weight",
|
||||
".gate_up_proj.weight_scale": ".w13_weight_scale",
|
||||
".gate_up_proj.bias": ".w13_bias",
|
||||
".gate_up_proj.input_scale": ".w13_input_scale",
|
||||
".down_proj.weight": ".w2_weight",
|
||||
".down_proj.weight_scale": ".w2_weight_scale",
|
||||
".down_proj.bias": ".w2_bias",
|
||||
".down_proj.input_scale": ".w2_input_scale",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -725,18 +1210,6 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, weight scales, activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
self,
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_local_experts,
|
||||
num_redundant_experts=0,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user