[ROCm][Quantization] GPT_OSS in amd-quark format model loading and emulations (#29008)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
xuebwang-amd
2026-02-10 23:08:05 +08:00
committed by GitHub
parent 599e4335a4
commit b129136c7a
13 changed files with 1094 additions and 213 deletions

View File

@@ -22,7 +22,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
triton_kernel_moe_forward,
)
@@ -298,12 +298,18 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
quant_config = FusedMoEQuantConfig.make(
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_scale=pc1,
w2_scale=pc2,
)
if a_dtype == "bf16" and w_dtype == "mx4":
quant_config = mxfp4_w4a16_moe_quant_config(
w1_scale=pc1,
w2_scale=pc2,
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
)
else:
raise NotImplementedError(
f"Quantization configuration for activation={a_dtype} and weight={w_dtype} "
f"has not been implemented."
)
out_triton_monolithic = triton_kernel_moe_forward(
hidden_states=x_tri,