[Perf] fused_moe: add int4_w4a16 benchmark support and tuning config (#34130)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
Matthias Gehre
2026-02-13 09:14:27 +01:00
committed by GitHub
parent 742d214d6e
commit 934acddef9
2 changed files with 185 additions and 8 deletions

View File

@@ -100,13 +100,38 @@ def benchmark_config(
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool = False,
num_iters: int = 100,
block_quant_shape: list[int] = None,
use_deep_gemm: bool = False,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16:
if use_int4_w4a16:
# Int4 packed weights: 2 int4 values per uint8 byte
# K dimension is packed (halved)
intermediate_size = shard_intermediate_size // 2 # after silu_and_mul
w1 = torch.randint(
0,
255,
(
num_experts,
shard_intermediate_size,
hidden_size // 2, # int4 packing
),
dtype=torch.uint8,
)
w2 = torch.randint(
0,
255,
(
num_experts,
hidden_size,
intermediate_size // 2, # int4 packing
),
dtype=torch.uint8,
)
elif use_int8_w8a16:
w1 = torch.randint(
-127,
127,
@@ -140,7 +165,20 @@ def benchmark_config(
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
if use_int4_w4a16:
if block_quant_shape is None:
raise ValueError("block_quant_shape is required for int4_w4a16")
group_size = block_quant_shape[1]
# Scales shape: (E, N, K // group_size), allocated in `dtype` (same as the activations)
w1_scale = torch.rand(
(num_experts, shard_intermediate_size, hidden_size // group_size),
dtype=dtype,
)
w2_scale = torch.rand(
(num_experts, hidden_size, intermediate_size // group_size),
dtype=dtype,
)
elif use_int8_w8a16:
w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
@@ -199,6 +237,7 @@ def benchmark_config(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
weight_dtype="int4" if use_int4_w4a16 else None,
)
deep_gemm_experts = None
@@ -481,6 +520,7 @@ class BenchmarkWorker:
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool = False,
block_quant_shape: list[int] = None,
use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]:
@@ -488,7 +528,10 @@ class BenchmarkWorker:
set_random_seed(self.seed)
dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int4_w4a16=use_int4_w4a16,
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
@@ -519,6 +562,7 @@ class BenchmarkWorker:
dtype,
use_fp8_w8a8,
use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
num_iters=100,
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm,
@@ -535,6 +579,7 @@ class BenchmarkWorker:
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
search_space: list[dict[str, int]],
block_quant_shape: list[int],
use_deep_gemm: bool,
@@ -545,7 +590,7 @@ class BenchmarkWorker:
best_config = None
best_time = float("inf")
if current_platform.is_rocm():
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
search_space = prune_rocm_search_space(
num_tokens,
shard_intermediate_size,
@@ -574,6 +619,7 @@ class BenchmarkWorker:
dtype,
use_fp8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
num_iters=20,
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm,
@@ -621,6 +667,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
else {}
),
**({"kpack": config["kpack"]} if "kpack" in config else {}),
**({"SPLIT_K": config["SPLIT_K"]} if "SPLIT_K" in config else {}),
}
@@ -633,11 +680,15 @@ def save_configs(
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
block_quant_shape: list[int],
save_dir: str,
) -> None:
dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int4_w4a16=use_int4_w4a16,
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
@@ -739,6 +790,38 @@ def get_model_params(config):
return E, topk, intermediate_size, hidden_size
def get_quantization_group_size(config) -> int | None:
"""Extract the quantization group size from the HF model config.
This reads directly from the HuggingFace config object (as returned by
``get_config()``), not from vLLM's quantization config classes.
Supports AWQ/GPTQ-style configs (direct 'group_size' key) and
compressed-tensors configs (nested inside 'config_groups').
"""
quantization_config = getattr(config, "quantization_config", {})
if not isinstance(quantization_config, dict):
return None
# AWQ / GPTQ style: group_size is a top-level key
gs = quantization_config.get("group_size")
if gs is not None:
return gs
# compressed-tensors style: group_size is nested in config_groups
config_groups = quantization_config.get("config_groups", {})
if not isinstance(config_groups, dict):
return None
for group_cfg in config_groups.values():
if not isinstance(group_cfg, dict):
continue
weights = group_cfg.get("weights", {})
if not isinstance(weights, dict):
continue
gs = weights.get("group_size")
if gs is not None:
return gs
return None
def main(args: argparse.Namespace):
print(args)
@@ -757,7 +840,20 @@ def main(args: argparse.Namespace):
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
use_int4_w4a16 = args.dtype == "int4_w4a16"
block_quant_shape = get_weight_block_size_safety(config)
if use_int4_w4a16:
group_size = get_quantization_group_size(config)
if group_size is None:
raise ValueError(
"Could not determine group_size from model config. "
"The model's quantization_config must contain a 'group_size' "
"field (AWQ/GPTQ) or 'config_groups.*.weights.group_size' "
"(compressed-tensors)."
)
# For int4_w4a16, block_shape = [0, group_size]
# block_shape[0]=0 means no block quantization on N dimension
block_quant_shape = [0, group_size]
if args.batch_size is None:
batch_sizes = [
@@ -811,8 +907,20 @@ def main(args: argparse.Namespace):
return ray.get(outputs)
if args.tune:
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
# int4_w4a16 weights are uint8-packed, not fp16; treat like fp8 for
# search space generation (no matrix_instr_nonkdim/kpack exploration).
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
# For int4_w4a16, the group_size constraint on BLOCK_SIZE_K does not
# apply: the gptq_awq kernel handles arbitrary BLOCK_SIZE_K regardless
# of group_size. Skip block_quant_shape filtering to keep the full
# search space (e.g. BLOCK_SIZE_K=64 with group_size=128).
tune_block_quant_shape = None if use_int4_w4a16 else block_quant_shape
search_space = get_configs_compute_bound(is_fp16, tune_block_quant_shape)
if use_int4_w4a16:
# SPLIT_K is a required kernel constexpr for gptq_awq kernel;
# only SPLIT_K=1 is used at runtime, so fix it during tuning.
for cfg in search_space:
cfg["SPLIT_K"] = 1
print(f"Start tuning over {len(search_space)} configurations...")
if use_deep_gemm:
raise ValueError(
@@ -832,6 +940,7 @@ def main(args: argparse.Namespace):
dtype,
use_fp8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
search_space,
block_quant_shape,
use_deep_gemm,
@@ -851,6 +960,7 @@ def main(args: argparse.Namespace):
dtype,
use_fp8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
block_quant_shape,
args.save_dir,
)
@@ -869,6 +979,7 @@ def main(args: argparse.Namespace):
dtype,
use_fp8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
block_quant_shape,
use_deep_gemm,
)
@@ -891,7 +1002,10 @@ if __name__ == "__main__":
)
parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
"--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16", "int4_w4a16"],
default="auto",
)
parser.add_argument("--use-deep-gemm", action="store_true")
parser.add_argument(