Add option to use DeepGemm contiguous grouped gemm kernel for fused MoE operations. (#13932)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-04-01 12:07:43 -04:00
committed by GitHub
parent a57a3044aa
commit e59ca942f5
6 changed files with 773 additions and 114 deletions

View File

@@ -30,19 +30,18 @@ class BenchmarkConfig(TypedDict):
num_stages: int
def benchmark_config(
config: BenchmarkConfig,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
block_quant_shape: List[int] = None,
) -> float:
def benchmark_config(config: BenchmarkConfig,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
block_quant_shape: List[int] = None,
use_deep_gemm: bool = False) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16:
@@ -115,22 +114,41 @@ def benchmark_config(
def run():
from vllm.model_executor.layers.fused_moe import override_config
with override_config(config):
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
)
if use_deep_gemm:
topk_weights, topk_ids = fused_topk(x, input_gating, topk,
False)
return fused_experts(
x,
w1,
w2,
topk_weights,
topk_ids,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
allow_deep_gemm=True,
)
else:
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
)
# JIT compilation & warmup
run()
@@ -366,6 +384,7 @@ class BenchmarkWorker:
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
block_quant_shape: List[int] = None,
use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype,
@@ -396,7 +415,8 @@ class BenchmarkWorker:
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
block_quant_shape=block_quant_shape)
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm)
return config, kernel_time
def tune(
@@ -411,6 +431,7 @@ class BenchmarkWorker:
use_int8_w8a16: bool,
search_space: list[dict[str, int]],
block_quant_shape: list[int],
use_deep_gemm: bool,
) -> dict[str, int]:
best_config = None
best_time = float("inf")
@@ -436,7 +457,8 @@ class BenchmarkWorker:
use_fp8_w8a8,
use_int8_w8a16,
num_iters=20,
block_quant_shape=block_quant_shape)
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
@@ -550,6 +572,8 @@ def main(args: argparse.Namespace):
else:
batch_sizes = [args.batch_size]
use_deep_gemm = bool(args.use_deep_gemm)
ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
@@ -572,10 +596,10 @@ def main(args: argparse.Namespace):
start = time.time()
configs = _distribute(
"tune",
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
for batch_size in batch_sizes])
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space,
block_quant_shape, use_deep_gemm)
for batch_size in batch_sizes])
best_configs = {
M: sort_config(config)
for M, config in zip(batch_sizes, configs)
@@ -589,7 +613,7 @@ def main(args: argparse.Namespace):
outputs = _distribute(
"benchmark",
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm)
for batch_size in batch_sizes])
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
@@ -611,6 +635,7 @@ if __name__ == "__main__":
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto")
parser.add_argument("--use-deep-gemm", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")