[ROCm][Quantization] GPT_OSS in amd-quark format model loading and emulations (#29008)
Signed-off-by: xuebwang-amd <xuebwang@amd.com> Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
@@ -298,12 +298,18 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
|
||||
pc2,
|
||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_scale=pc1,
|
||||
w2_scale=pc2,
|
||||
)
|
||||
if a_dtype == "bf16" and w_dtype == "mx4":
|
||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale=pc1,
|
||||
w2_scale=pc2,
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Quantization configuration for activation={a_dtype} and weight={w_dtype} "
|
||||
f"has not been implemented."
|
||||
)
|
||||
|
||||
out_triton_monolithic = triton_kernel_moe_forward(
|
||||
hidden_states=x_tri,
|
||||
|
||||
Reference in New Issue
Block a user