diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 5ee1cf199..e086a109f 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -100,13 +100,38 @@ def benchmark_config(
     dtype: torch.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    use_int4_w4a16: bool = False,
     num_iters: int = 100,
     block_quant_shape: list[int] = None,
     use_deep_gemm: bool = False,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
-    if use_int8_w8a16:
+    if use_int4_w4a16:
+        # Int4 packed weights: 2 int4 values per uint8 byte,
+        # so the K dimension is packed (halved).
+        intermediate_size = shard_intermediate_size // 2  # after silu_and_mul
+        w1 = torch.randint(
+            0,
+            255,
+            (
+                num_experts,
+                shard_intermediate_size,
+                hidden_size // 2,  # int4 packing
+            ),
+            dtype=torch.uint8,
+        )
+        w2 = torch.randint(
+            0,
+            255,
+            (
+                num_experts,
+                hidden_size,
+                intermediate_size // 2,  # int4 packing
+            ),
+            dtype=torch.uint8,
+        )
+    elif use_int8_w8a16:
         w1 = torch.randint(
             -127,
             127,
@@ -140,7 +165,20 @@ def benchmark_config(
     w2_scale = None
     a1_scale = None
     a2_scale = None
-    if use_int8_w8a16:
+    if use_int4_w4a16:
+        if block_quant_shape is None:
+            raise ValueError("block_quant_shape is required for int4_w4a16")
+        group_size = block_quant_shape[1]
+        # Scales shape: (E, N, K // group_size) in the activation dtype
+        w1_scale = torch.rand(
+            (num_experts, shard_intermediate_size, hidden_size // group_size),
+            dtype=dtype,
+        )
+        w2_scale = torch.rand(
+            (num_experts, hidden_size, intermediate_size // group_size),
+            dtype=dtype,
+        )
+    elif use_int8_w8a16:
         w1_scale = torch.randn(
             (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
         )
@@ -199,6 +237,7 @@ def benchmark_config(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_quant_shape,
+        weight_dtype="int4" if use_int4_w4a16 else None,
     )

     deep_gemm_experts = None
@@ -481,6 +520,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
+        use_int4_w4a16: bool = False,
         block_quant_shape: list[int] = None,
         use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
@@ -488,7 +528,10 @@ class BenchmarkWorker:
         set_random_seed(self.seed)

         dtype_str = _get_config_dtype_str(
-            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
+            dtype,
+            use_int8_w8a16=use_int8_w8a16,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int4_w4a16=use_int4_w4a16,
         )
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
         # is the intermediate size after silu_and_mul.
@@ -519,6 +562,7 @@ class BenchmarkWorker:
             dtype,
             use_fp8_w8a8,
             use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
             num_iters=100,
             block_quant_shape=block_quant_shape,
             use_deep_gemm=use_deep_gemm,
@@ -535,6 +579,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
+        use_int4_w4a16: bool,
         search_space: list[dict[str, int]],
         block_quant_shape: list[int],
         use_deep_gemm: bool,
@@ -545,7 +590,7 @@ class BenchmarkWorker:
         best_config = None
         best_time = float("inf")
         if current_platform.is_rocm():
-            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
+            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
             search_space = prune_rocm_search_space(
                 num_tokens,
                 shard_intermediate_size,
@@ -574,6 +619,7 @@ class BenchmarkWorker:
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    use_int4_w4a16,
                     num_iters=20,
                     block_quant_shape=block_quant_shape,
                     use_deep_gemm=use_deep_gemm,
@@ -621,6 +667,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
             else {}
         ),
         **({"kpack": config["kpack"]} if "kpack" in config else {}),
+        **({"SPLIT_K": config["SPLIT_K"]} if "SPLIT_K" in config else {}),
     }


@@ -633,11 +680,15 @@ def save_configs(
     dtype: torch.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
     block_quant_shape: list[int],
     save_dir: str,
 ) -> None:
     dtype_str = _get_config_dtype_str(
-        dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
+        dtype,
+        use_int8_w8a16=use_int8_w8a16,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int4_w4a16=use_int4_w4a16,
     )

     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
@@ -739,6 +790,38 @@ def get_model_params(config):
     return E, topk, intermediate_size, hidden_size


+def get_quantization_group_size(config) -> int | None:
+    """Extract the quantization group size from the HF model config.
+
+    This reads directly from the HuggingFace config object (as returned by
+    ``get_config()``), not from vLLM's quantization config classes.
+
+    Supports AWQ/GPTQ-style configs (direct 'group_size' key) and
+    compressed-tensors configs (nested inside 'config_groups').
+    """
+    quantization_config = getattr(config, "quantization_config", {})
+    if not isinstance(quantization_config, dict):
+        return None
+    # AWQ / GPTQ style: group_size is a top-level key
+    gs = quantization_config.get("group_size")
+    if gs is not None:
+        return gs
+    # compressed-tensors style: group_size is nested in config_groups
+    config_groups = quantization_config.get("config_groups", {})
+    if not isinstance(config_groups, dict):
+        return None
+    for group_cfg in config_groups.values():
+        if not isinstance(group_cfg, dict):
+            continue
+        weights = group_cfg.get("weights", {})
+        if not isinstance(weights, dict):
+            continue
+        gs = weights.get("group_size")
+        if gs is not None:
+            return gs
+    return None
+
+
 def main(args: argparse.Namespace):
     print(args)

@@ -757,7 +840,20 @@ def main(args: argparse.Namespace):
     dtype = torch.float16 if current_platform.is_rocm() else config.dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
+    use_int4_w4a16 = args.dtype == "int4_w4a16"
     block_quant_shape = get_weight_block_size_safety(config)
+    if use_int4_w4a16:
+        group_size = get_quantization_group_size(config)
+        if group_size is None:
+            raise ValueError(
+                "Could not determine group_size from model config. "
+                "The model's quantization_config must contain a 'group_size' "
+                "field (AWQ/GPTQ) or 'config_groups.*.weights.group_size' "
+                "(compressed-tensors)."
+            )
+        # For int4_w4a16, block_shape = [0, group_size]
+        # block_shape[0]=0 means no block quantization on N dimension
+        block_quant_shape = [0, group_size]

     if args.batch_size is None:
         batch_sizes = [
@@ -811,8 +907,20 @@ def main(args: argparse.Namespace):
         return ray.get(outputs)

     if args.tune:
-        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
-        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
+        # int4_w4a16 weights are uint8-packed, not fp16; treat like fp8 for
+        # search space generation (no matrix_instr_nonkdim/kpack exploration).
+        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
+        # For int4_w4a16, the group_size constraint on BLOCK_SIZE_K does not
+        # apply: the gptq_awq kernel handles arbitrary BLOCK_SIZE_K regardless
+        # of group_size. Skip block_quant_shape filtering to keep the full
+        # search space (e.g. BLOCK_SIZE_K=64 with group_size=128).
+        tune_block_quant_shape = None if use_int4_w4a16 else block_quant_shape
+        search_space = get_configs_compute_bound(is_fp16, tune_block_quant_shape)
+        if use_int4_w4a16:
+            # SPLIT_K is a required kernel constexpr for gptq_awq kernel;
+            # only SPLIT_K=1 is used at runtime, so fix it during tuning.
+            for cfg in search_space:
+                cfg["SPLIT_K"] = 1
         print(f"Start tuning over {len(search_space)} configurations...")
         if use_deep_gemm:
             raise ValueError(
@@ -832,6 +940,7 @@ def main(args: argparse.Namespace):
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    use_int4_w4a16,
                     search_space,
                     block_quant_shape,
                     use_deep_gemm,
@@ -851,6 +960,7 @@ def main(args: argparse.Namespace):
             dtype,
             use_fp8_w8a8,
             use_int8_w8a16,
+            use_int4_w4a16,
             block_quant_shape,
             args.save_dir,
         )
@@ -869,6 +979,7 @@ def main(args: argparse.Namespace):
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    use_int4_w4a16,
                     block_quant_shape,
                     use_deep_gemm,
                 )
@@ -891,7 +1002,10 @@ if __name__ == "__main__":
     )
     parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
     parser.add_argument(
-        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
+        "--dtype",
+        type=str,
+        choices=["auto", "fp8_w8a8", "int8_w8a16", "int4_w4a16"],
+        default="auto",
     )
     parser.add_argument("--use-deep-gemm", action="store_true")
     parser.add_argument(
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=Radeon_8060S_Graphics,dtype=int4_w4a16.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=Radeon_8060S_Graphics,dtype=int4_w4a16.json
new file mode 100644
index 000000000..479bff1c2
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=Radeon_8060S_Graphics,dtype=int4_w4a16.json
@@ -0,0 +1,63 @@
+{
+    "triton_version": "3.6.0",
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "SPLIT_K": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "SPLIT_K": 1,
+        "num_warps": 1,
+        "num_stages": 2,
+        "waves_per_eu": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "SPLIT_K": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "SPLIT_K": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 2
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 4,
+        "SPLIT_K": 1,
+        "num_warps": 2,
+        "num_stages": 2,
"waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + } +}