212 lines
5.1 KiB
Python
212 lines
5.1 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
|
||
|
|
from collections.abc import Callable, Iterable
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from itertools import product
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch.nn.functional as F
|
||
|
|
import torch.utils.benchmark as TBenchmark
|
||
|
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||
|
|
from tqdm import tqdm
|
||
|
|
|
||
|
|
import vllm._custom_ops as ops
|
||
|
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||
|
|
per_token_group_quant_fp8,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class bench_params_t:
    # One benchmark configuration: input shape, activation dtype and
    # quantization group size.
    num_tokens: int
    hidden_size: int
    dtype: torch.dtype
    group_size: int  # Changed from list[int] to int

    def description(self):
        """Return a one-line human-readable label for this configuration."""
        return (
            f"N {self.num_tokens} x D {self.hidden_size} "
            f"x DT {self.dtype} x GS {self.group_size}"
        )
|
|
||
|
|
def get_bench_params() -> list[bench_params_t]:
    """Test configurations covering common model sizes.

    Returns:
        The Cartesian product of token counts, hidden sizes, dtypes and
        quantization group sizes, one ``bench_params_t`` per combination.
    """
    NUM_TOKENS = [16, 128, 512, 2048]
    HIDDEN_SIZES = [1024, 2048, 4096, 5120, 14336]  # Common FFN sizes
    DTYPES = [torch.float16, torch.bfloat16]
    GROUP_SIZES = [64, 128]  # Changed from [[1, 64], [1, 128]]

    # Starred unpacking in a comprehension replaces the original
    # map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), ...) construct.
    return [
        bench_params_t(*combo)
        for combo in product(NUM_TOKENS, HIDDEN_SIZES, DTYPES, GROUP_SIZES)
    ]
||
|
|
|
||
|
|
|
||
|
|
# Reference implementations
def unfused_fp8_impl(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    group_size: int,  # Changed from list[int]
):
    """Unfused: SiLU+Mul then per-tensor quantize.

    Args:
        x: activations of shape [..., 2 * hidden], laid out as [gate || up].
        quant_dtype: unused by this impl; kept so all impls share a signature.
        group_size: unused by this impl; kept so all impls share a signature.

    Returns:
        The per-tensor FP8-quantized activation tensor.
    """
    hidden = x.shape[-1] // 2
    gate, up = x.split(hidden, dim=-1)

    # SiLU(gate) * up
    silu_out = F.silu(gate) * up

    # Per-tensor quantize (no group_size used here)
    silu_out, _ = ops.scaled_fp8_quant(silu_out)
    # Fix: the result was previously computed and discarded; return it so
    # callers can inspect/validate the output (benchmark timing is unaffected).
    return silu_out
||
|
|
|
||
|
|
def unfused_groupwise_fp8_impl(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    group_size: int,  # Changed from list[int]
):
    """Unfused: SiLU+Mul then group-wise quantize.

    Args:
        x: activations of shape [..., 2 * hidden], laid out as [gate || up].
        quant_dtype: unused by this impl; kept so all impls share a signature.
        group_size: number of elements per quantization group.

    Returns:
        The group-wise FP8-quantized activation tensor.
    """
    hidden = x.shape[-1] // 2
    gate, up = x.split(hidden, dim=-1)

    # SiLU(gate) * up
    silu_out = F.silu(gate) * up

    # Group quantize - use group_size directly
    silu_out, _ = per_token_group_quant_fp8(
        silu_out, group_size=group_size, use_ue8m0=False
    )
    # Fix: the result was previously computed and discarded; return it so
    # callers can inspect/validate the output (benchmark timing is unaffected).
    return silu_out
|
|
||
|
|
|
||
|
|
def fused_impl(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    group_size: int,
):
    """Fused: SiLU+Mul+Block Quantization in single kernel.

    Args:
        x: activations of shape [..., 2 * hidden], laid out as [gate || up].
        quant_dtype: target quantization dtype passed to the fused kernel.
        group_size: number of elements per quantization group.

    Returns:
        The fused-kernel quantized activation tensor.
    """
    out, _ = ops.silu_and_mul_per_block_quant(
        x,
        group_size=group_size,
        quant_dtype=quant_dtype,
        is_scale_transposed=False,
    )
    # Fix: the result was previously computed and discarded; return it for
    # consistency with the unfused reference implementations.
    return out
|
|
||
|
|
|
||
|
|
# Bench functions
def bench_fn(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    group_size: int,
    label: str,
    sub_label: str,
    fn: Callable,
    description: str,
) -> TMeasurement:
    """Time ``fn(x, quant_dtype, group_size)`` with torch.utils.benchmark.

    Args:
        x: input activations forwarded to ``fn``.
        quant_dtype: quantization dtype forwarded to ``fn``.
        group_size: quantization group size forwarded to ``fn``.
        label: top-level benchmark label shown by Compare.
        sub_label: per-configuration row label.
        fn: one of the *_impl callables defined in this file.
        description: column label identifying the implementation.

    Returns:
        A blocked-autorange measurement collected over at least 1 second.
    """
    min_run_time = 1

    # Renamed from ``globals`` to avoid shadowing the builtin of that name.
    timer_globals = {
        "x": x,
        "quant_dtype": quant_dtype,
        "group_size": group_size,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(x, quant_dtype, group_size)",
        globals=timer_globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)
|
|
||
|
|
|
||
|
|
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
    """Run benchmarks for all implementations of one configuration.

    Args:
        params: the shape/dtype/group-size configuration to benchmark.
        label: top-level benchmark label shown by Compare.
        sub_label: per-configuration row label (typically params.description()).

    Returns:
        One measurement per implementation (unfused per-tensor, unfused
        group-wise, fused group-wise).
    """
    # Make inputs: [num_tokens, hidden_size * 2] for [gate || up]
    scale = 1 / params.hidden_size
    x = (
        torch.randn(
            params.num_tokens,
            params.hidden_size * 2,
            dtype=params.dtype,
            device="cuda",
        )
        * scale
    )

    # Data-driven dispatch replaces three copy-pasted bench_fn call sites
    # that differed only in the implementation and its description string.
    variants = [
        (unfused_fp8_impl, "unfused_fp8_impl"),
        (unfused_groupwise_fp8_impl, "unfused_groupwise_fp8_impl"),
        (fused_impl, "fused_groupwise_fp8_impl"),
    ]
    return [
        bench_fn(
            x,
            torch.float8_e4m3fn,
            params.group_size,
            label,
            sub_label,
            impl,
            description,
        )
        for impl, description in variants
    ]
|
|
|
||
|
|
|
||
|
|
def print_timers(timers: Iterable[TMeasurement]):
    """Render all collected measurements as one comparison table on stdout."""
    TBenchmark.Compare(timers).print()
|
|
|
||
|
|
|
||
|
|
def main():
    """Run every benchmark configuration and print one combined table."""
    torch.set_default_device("cuda")
    bench_params = get_bench_params()

    print(f"Running {len(bench_params)} benchmark configurations...")
    print(
        f"This will take approximately {len(bench_params) * 3} seconds (1s per variant)"
    )
    print()

    # Collect measurements across every configuration before printing,
    # so Compare can lay them out in a single table.
    timers = []
    for bp in tqdm(bench_params):
        timers.extend(bench(bp, "silu-mul-block-quant", bp.description()))

    banner = "=" * 80
    print("\n" + banner)
    print("FINAL COMPARISON - ALL RESULTS")
    print(banner)
    print_timers(timers)
|
|
|
||
|
|
|
||
|
|
# Script entry point: run the benchmark suite when executed directly.
if __name__ == "__main__":
    main()