[Perf] fused_moe: add int4_w4a16 benchmark support and tuning config (#34130)
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
@@ -100,13 +100,38 @@ def benchmark_config(
     dtype: torch.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    use_int4_w4a16: bool = False,
     num_iters: int = 100,
     block_quant_shape: list[int] = None,
     use_deep_gemm: bool = False,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
-    if use_int8_w8a16:
+    if use_int4_w4a16:
+        # Int4 packed weights: 2 int4 values per uint8 byte
+        # K dimension is packed (halved)
+        intermediate_size = shard_intermediate_size // 2  # after silu_and_mul
+        w1 = torch.randint(
+            0,
+            255,
+            (
+                num_experts,
+                shard_intermediate_size,
+                hidden_size // 2,  # int4 packing
+            ),
+            dtype=torch.uint8,
+        )
+        w2 = torch.randint(
+            0,
+            255,
+            (
+                num_experts,
+                hidden_size,
+                intermediate_size // 2,  # int4 packing
+            ),
+            dtype=torch.uint8,
+        )
+    elif use_int8_w8a16:
         w1 = torch.randint(
             -127,
             127,
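Aside (not part of the diff): the packed-weight layout above is easier to see in isolation. A minimal sketch, assuming low-nibble-first packing (the actual kernel's nibble order may differ); two int4 values share one uint8 byte, so the packed K dimension is K // 2.

    import torch

    # Two 4-bit values per uint8 byte: the packed K dimension is K // 2.
    K = 8
    vals = torch.randint(0, 16, (K,), dtype=torch.uint8)  # logical int4 values
    packed = vals[0::2] | (vals[1::2] << 4)               # shape (K // 2,)

    # Unpacking the nibbles recovers the original values.
    low = packed & 0xF
    high = (packed >> 4) & 0xF
    assert torch.equal(torch.stack([low, high], dim=1).flatten(), vals)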
@@ -140,7 +165,20 @@ def benchmark_config(
     w2_scale = None
     a1_scale = None
     a2_scale = None
-    if use_int8_w8a16:
+    if use_int4_w4a16:
+        if block_quant_shape is None:
+            raise ValueError("block_quant_shape is required for int4_w4a16")
+        group_size = block_quant_shape[1]
+        # Scales shape: (E, N, K // group_size) in fp16
+        w1_scale = torch.rand(
+            (num_experts, shard_intermediate_size, hidden_size // group_size),
+            dtype=dtype,
+        )
+        w2_scale = torch.rand(
+            (num_experts, hidden_size, intermediate_size // group_size),
+            dtype=dtype,
+        )
+    elif use_int8_w8a16:
         w1_scale = torch.randn(
             (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
         )
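Aside (not part of the diff): a quick shape check of the group-wise scales allocated above, one scale per group_size elements along K. The sizes below are assumptions chosen only for illustration, not values from the commit.

    import torch

    # Hypothetical sizes, chosen only to make the scale shapes concrete.
    num_experts, hidden_size, shard_intermediate_size, group_size = 8, 4096, 14336, 128
    intermediate_size = shard_intermediate_size // 2  # after silu_and_mul

    w1_scale = torch.rand(
        (num_experts, shard_intermediate_size, hidden_size // group_size),
        dtype=torch.float16,
    )
    w2_scale = torch.rand(
        (num_experts, hidden_size, intermediate_size // group_size),
        dtype=torch.float16,
    )
    print(w1_scale.shape)  # torch.Size([8, 14336, 32])
    print(w2_scale.shape)  # torch.Size([8, 4096, 56])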
@@ -199,6 +237,7 @@ def benchmark_config(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_quant_shape,
+        weight_dtype="int4" if use_int4_w4a16 else None,
     )
 
     deep_gemm_experts = None
@@ -481,6 +520,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
+        use_int4_w4a16: bool = False,
         block_quant_shape: list[int] = None,
         use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
@@ -488,7 +528,10 @@ class BenchmarkWorker:
 
         set_random_seed(self.seed)
         dtype_str = _get_config_dtype_str(
-            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
+            dtype,
+            use_int8_w8a16=use_int8_w8a16,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int4_w4a16=use_int4_w4a16,
         )
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
         # is the intermediate size after silu_and_mul.
@@ -519,6 +562,7 @@ class BenchmarkWorker:
             dtype,
             use_fp8_w8a8,
             use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
             num_iters=100,
             block_quant_shape=block_quant_shape,
             use_deep_gemm=use_deep_gemm,
@@ -535,6 +579,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
+        use_int4_w4a16: bool,
         search_space: list[dict[str, int]],
         block_quant_shape: list[int],
         use_deep_gemm: bool,
@@ -545,7 +590,7 @@ class BenchmarkWorker:
         best_config = None
         best_time = float("inf")
         if current_platform.is_rocm():
-            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
+            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
             search_space = prune_rocm_search_space(
                 num_tokens,
                 shard_intermediate_size,
@@ -574,6 +619,7 @@ class BenchmarkWorker:
                         dtype,
                         use_fp8_w8a8,
                         use_int8_w8a16,
+                        use_int4_w4a16,
                         num_iters=20,
                         block_quant_shape=block_quant_shape,
                         use_deep_gemm=use_deep_gemm,
@@ -621,6 +667,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
             else {}
         ),
         **({"kpack": config["kpack"]} if "kpack" in config else {}),
+        **({"SPLIT_K": config["SPLIT_K"]} if "SPLIT_K" in config else {}),
     }
 
 
@@ -633,11 +680,15 @@ def save_configs(
     dtype: torch.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
     block_quant_shape: list[int],
     save_dir: str,
 ) -> None:
     dtype_str = _get_config_dtype_str(
-        dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
+        dtype,
+        use_int8_w8a16=use_int8_w8a16,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int4_w4a16=use_int4_w4a16,
     )
 
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
@@ -739,6 +790,38 @@ def get_model_params(config):
     return E, topk, intermediate_size, hidden_size
 
 
+def get_quantization_group_size(config) -> int | None:
+    """Extract the quantization group size from the HF model config.
+
+    This reads directly from the HuggingFace config object (as returned by
+    ``get_config()``), not from vLLM's quantization config classes.
+
+    Supports AWQ/GPTQ-style configs (direct 'group_size' key) and
+    compressed-tensors configs (nested inside 'config_groups').
+    """
+    quantization_config = getattr(config, "quantization_config", {})
+    if not isinstance(quantization_config, dict):
+        return None
+    # AWQ / GPTQ style: group_size is a top-level key
+    gs = quantization_config.get("group_size")
+    if gs is not None:
+        return gs
+    # compressed-tensors style: group_size is nested in config_groups
+    config_groups = quantization_config.get("config_groups", {})
+    if not isinstance(config_groups, dict):
+        return None
+    for group_cfg in config_groups.values():
+        if not isinstance(group_cfg, dict):
+            continue
+        weights = group_cfg.get("weights", {})
+        if not isinstance(weights, dict):
+            continue
+        gs = weights.get("group_size")
+        if gs is not None:
+            return gs
+    return None
+
+
 def main(args: argparse.Namespace):
     print(args)
 
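Aside (not part of the diff): get_quantization_group_size only inspects plain dicts on the HF config object, so its behaviour can be sketched with mock configs. The dict layouts below mirror the docstring, not any specific model, and the sketch assumes the helper defined above is in scope.

    from types import SimpleNamespace

    # AWQ/GPTQ style: group_size is a top-level key of quantization_config.
    awq_like = SimpleNamespace(
        quantization_config={"quant_method": "awq", "bits": 4, "group_size": 128}
    )
    # compressed-tensors style: group_size nested under config_groups.*.weights.
    ct_like = SimpleNamespace(
        quantization_config={
            "config_groups": {"group_0": {"weights": {"num_bits": 4, "group_size": 128}}}
        }
    )

    print(get_quantization_group_size(awq_like))           # 128
    print(get_quantization_group_size(ct_like))            # 128
    print(get_quantization_group_size(SimpleNamespace()))  # None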
@@ -757,7 +840,20 @@ def main(args: argparse.Namespace):
     dtype = torch.float16 if current_platform.is_rocm() else config.dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
+    use_int4_w4a16 = args.dtype == "int4_w4a16"
     block_quant_shape = get_weight_block_size_safety(config)
+    if use_int4_w4a16:
+        group_size = get_quantization_group_size(config)
+        if group_size is None:
+            raise ValueError(
+                "Could not determine group_size from model config. "
+                "The model's quantization_config must contain a 'group_size' "
+                "field (AWQ/GPTQ) or 'config_groups.*.weights.group_size' "
+                "(compressed-tensors)."
+            )
+        # For int4_w4a16, block_shape = [0, group_size]
+        # block_shape[0]=0 means no block quantization on N dimension
+        block_quant_shape = [0, group_size]
 
     if args.batch_size is None:
         batch_sizes = [
@@ -811,8 +907,20 @@ def main(args: argparse.Namespace):
         return ray.get(outputs)
 
     if args.tune:
-        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
-        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
+        # int4_w4a16 weights are uint8-packed, not fp16; treat like fp8 for
+        # search space generation (no matrix_instr_nonkdim/kpack exploration).
+        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
+        # For int4_w4a16, the group_size constraint on BLOCK_SIZE_K does not
+        # apply: the gptq_awq kernel handles arbitrary BLOCK_SIZE_K regardless
+        # of group_size. Skip block_quant_shape filtering to keep the full
+        # search space (e.g. BLOCK_SIZE_K=64 with group_size=128).
+        tune_block_quant_shape = None if use_int4_w4a16 else block_quant_shape
+        search_space = get_configs_compute_bound(is_fp16, tune_block_quant_shape)
+        if use_int4_w4a16:
+            # SPLIT_K is a required kernel constexpr for gptq_awq kernel;
+            # only SPLIT_K=1 is used at runtime, so fix it during tuning.
+            for cfg in search_space:
+                cfg["SPLIT_K"] = 1
         print(f"Start tuning over {len(search_space)} configurations...")
         if use_deep_gemm:
             raise ValueError(
@@ -832,6 +940,7 @@ def main(args: argparse.Namespace):
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    use_int4_w4a16,
                     search_space,
                     block_quant_shape,
                     use_deep_gemm,
@@ -851,6 +960,7 @@ def main(args: argparse.Namespace):
             dtype,
             use_fp8_w8a8,
             use_int8_w8a16,
+            use_int4_w4a16,
             block_quant_shape,
             args.save_dir,
         )
@@ -869,6 +979,7 @@ def main(args: argparse.Namespace):
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    use_int4_w4a16,
                     block_quant_shape,
                     use_deep_gemm,
                 )
@@ -891,7 +1002,10 @@ if __name__ == "__main__":
     )
     parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
     parser.add_argument(
-        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
+        "--dtype",
+        type=str,
+        choices=["auto", "fp8_w8a8", "int8_w8a16", "int4_w4a16"],
+        default="auto",
     )
     parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument(