Feature/silu block quant fusion v1 (#32996)
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
This commit is contained in:
@@ -340,7 +340,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
|
||||
211
benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py
Normal file
211
benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from tqdm import tqdm
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class bench_params_t:
|
||||
num_tokens: int
|
||||
hidden_size: int
|
||||
dtype: torch.dtype
|
||||
group_size: int # Changed from list[int] to int
|
||||
|
||||
def description(self):
|
||||
return (
|
||||
f"N {self.num_tokens} "
|
||||
f"x D {self.hidden_size} "
|
||||
f"x DT {self.dtype} "
|
||||
f"x GS {self.group_size}"
|
||||
)
|
||||
|
||||
|
||||
def get_bench_params() -> list[bench_params_t]:
|
||||
"""Test configurations covering common model sizes."""
|
||||
NUM_TOKENS = [16, 128, 512, 2048]
|
||||
HIDDEN_SIZES = [1024, 2048, 4096, 5120, 14336] # Common FFN sizes
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
GROUP_SIZES = [64, 128] # Changed from [[1, 64], [1, 128]]
|
||||
|
||||
combinations = product(NUM_TOKENS, HIDDEN_SIZES, DTYPES, GROUP_SIZES)
|
||||
bench_params = list(
|
||||
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
|
||||
)
|
||||
return bench_params
|
||||
|
||||
|
||||
# Reference implementations
|
||||
def unfused_fp8_impl(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int, # Changed from list[int]
|
||||
):
|
||||
"""Unfused: SiLU+Mul then per-tensor quantize."""
|
||||
hidden = x.shape[-1] // 2
|
||||
gate, up = x.split(hidden, dim=-1)
|
||||
|
||||
# SiLU(gate) * up
|
||||
silu_out = F.silu(gate) * up
|
||||
|
||||
# Per-tensor quantize (no group_size used here)
|
||||
silu_out, _ = ops.scaled_fp8_quant(silu_out)
|
||||
|
||||
|
||||
def unfused_groupwise_fp8_impl(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int, # Changed from list[int]
|
||||
):
|
||||
"""Unfused: SiLU+Mul then group-wise quantize."""
|
||||
hidden = x.shape[-1] // 2
|
||||
gate, up = x.split(hidden, dim=-1)
|
||||
|
||||
# SiLU(gate) * up
|
||||
silu_out = F.silu(gate) * up
|
||||
|
||||
# Group quantize - use group_size directly
|
||||
silu_out, _ = per_token_group_quant_fp8(
|
||||
silu_out, group_size=group_size, use_ue8m0=False
|
||||
)
|
||||
|
||||
|
||||
def fused_impl(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int,
|
||||
):
|
||||
"""Fused: SiLU+Mul+Block Quantization in single kernel."""
|
||||
out, _ = ops.silu_and_mul_per_block_quant(
|
||||
x,
|
||||
group_size=group_size,
|
||||
quant_dtype=quant_dtype,
|
||||
is_scale_transposed=False,
|
||||
)
|
||||
|
||||
|
||||
# Bench functions
|
||||
def bench_fn(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
fn: Callable,
|
||||
description: str,
|
||||
) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
"x": x,
|
||||
"quant_dtype": quant_dtype,
|
||||
"group_size": group_size,
|
||||
"fn": fn,
|
||||
}
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn(x, quant_dtype, group_size)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
|
||||
"""Run benchmarks for all implementations."""
|
||||
# Make inputs: [num_tokens, hidden_size * 2] for [gate || up]
|
||||
scale = 1 / params.hidden_size
|
||||
x = (
|
||||
torch.randn(
|
||||
params.num_tokens,
|
||||
params.hidden_size * 2,
|
||||
dtype=params.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
* scale
|
||||
)
|
||||
|
||||
timers = []
|
||||
|
||||
# Unfused per-tensor FP8
|
||||
timers.append(
|
||||
bench_fn(
|
||||
x,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_fp8_impl,
|
||||
"unfused_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
# Unfused group-wise FP8
|
||||
timers.append(
|
||||
bench_fn(
|
||||
x,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_groupwise_fp8_impl,
|
||||
"unfused_groupwise_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
# Fused group-wise FP8
|
||||
timers.append(
|
||||
bench_fn(
|
||||
x,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
fused_impl,
|
||||
"fused_groupwise_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def print_timers(timers: Iterable[TMeasurement]):
|
||||
compare = TBenchmark.Compare(timers)
|
||||
compare.print()
|
||||
|
||||
|
||||
def main():
|
||||
torch.set_default_device("cuda")
|
||||
bench_params = get_bench_params()
|
||||
|
||||
print(f"Running {len(bench_params)} benchmark configurations...")
|
||||
print(
|
||||
f"This will take approximately {len(bench_params) * 3} seconds (1s per variant)"
|
||||
)
|
||||
print()
|
||||
|
||||
timers = []
|
||||
for bp in tqdm(bench_params):
|
||||
result_timers = bench(bp, "silu-mul-block-quant", bp.description())
|
||||
timers.extend(result_timers)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("FINAL COMPARISON - ALL RESULTS")
|
||||
print("=" * 80)
|
||||
print_timers(timers)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -142,6 +142,12 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
std::optional<torch::Tensor> residual,
|
||||
int64_t group_size, bool is_scale_transposed);
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
bool is_scale_transposed);
|
||||
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
|
||||
169
csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu
Normal file
169
csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "../../dispatch_utils.h"
|
||||
#include "quant_conversions.cuh"
|
||||
#include "../w8a8/fp8/common.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Logic: one thread block per (token, group) pair
|
||||
|
||||
template <typename scalar_t, typename scalar_out_t, bool is_scale_transposed,
|
||||
int32_t group_size>
|
||||
__global__ void silu_and_mul_per_block_quant_kernel(
|
||||
scalar_out_t* __restrict__ out, // Output: [num_tokens, hidden_size] in
|
||||
// FP8/INT8
|
||||
float* __restrict__ scales, // Output: [num_tokens, hidden_size /
|
||||
// group_size] or [hidden_size / group_size,
|
||||
// num_tokens]
|
||||
scalar_t const* __restrict__ input, // Input: [num_tokens, hidden_size * 2]
|
||||
float const* scale_ub, // Optional scale upper bound
|
||||
int32_t const hidden_size // Output hidden size (input is 2x this)
|
||||
) {
|
||||
static_assert((group_size & (group_size - 1)) == 0,
|
||||
"group_size must be a power of 2 for correct reduction");
|
||||
|
||||
// Grid: (num_tokens, num_groups)
|
||||
int const token_idx = blockIdx.x;
|
||||
int const group_idx = blockIdx.y;
|
||||
int const tid = threadIdx.x; // tid in [0, group_size)
|
||||
int const num_tokens = gridDim.x;
|
||||
|
||||
// Input layout: [gate || up] concatenated along last dimension
|
||||
int const input_stride = hidden_size * 2;
|
||||
int const group_start = group_idx * group_size;
|
||||
|
||||
// Pointers to this token's data
|
||||
scalar_t const* token_input_gate =
|
||||
input + token_idx * input_stride + group_start;
|
||||
scalar_t const* token_input_up = token_input_gate + hidden_size;
|
||||
scalar_out_t* token_output = out + token_idx * hidden_size + group_start;
|
||||
|
||||
// Scale pointer for this group
|
||||
int const num_groups = gridDim.y;
|
||||
float* group_scale_ptr = is_scale_transposed
|
||||
? scales + group_idx * num_tokens + token_idx
|
||||
: scales + token_idx * num_groups + group_idx;
|
||||
|
||||
// Shared memory for reduction (compile-time sized)
|
||||
__shared__ float shared_max[group_size];
|
||||
|
||||
// Step 1: Each thread loads one element, computes SiLU, stores in register
|
||||
float gate = static_cast<float>(token_input_gate[tid]);
|
||||
float up = static_cast<float>(token_input_up[tid]);
|
||||
|
||||
// Compute SiLU(gate) * up
|
||||
float sigmoid_gate = 1.0f / (1.0f + expf(-gate));
|
||||
float silu_gate = gate * sigmoid_gate;
|
||||
float result = silu_gate * up; // Keep in register
|
||||
|
||||
// Step 2: Reduce to find group max
|
||||
shared_max[tid] = fabsf(result);
|
||||
__syncthreads();
|
||||
|
||||
// Power-of-2 reduction (group_size guaranteed to be power of 2)
|
||||
#pragma unroll
|
||||
for (int stride = group_size / 2; stride > 0; stride >>= 1) {
|
||||
if (tid < stride) {
|
||||
shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Step 3: Compute scale (thread 0), broadcast via shared memory
|
||||
if (tid == 0) {
|
||||
float group_max = shared_max[0];
|
||||
|
||||
float const quant_range = quant_type_max_v<scalar_out_t>;
|
||||
float group_scale = group_max / quant_range;
|
||||
|
||||
// Apply scale upper bound if provided
|
||||
if (scale_ub != nullptr) {
|
||||
group_scale = fminf(group_scale, *scale_ub);
|
||||
}
|
||||
|
||||
// Use minimum safe scaling factor
|
||||
group_scale = fmaxf(group_scale, min_scaling_factor<scalar_out_t>::val());
|
||||
|
||||
// Store scale to global memory
|
||||
*group_scale_ptr = group_scale;
|
||||
|
||||
// Reuse shared_max[0] to broadcast scale
|
||||
shared_max[0] = group_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float group_scale = shared_max[0];
|
||||
|
||||
// Step 4: Quantize and write output
|
||||
token_output[tid] =
|
||||
vllm::ScaledQuant<scalar_out_t, false>::quant_fn(result, group_scale);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
bool is_scale_transposed) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
|
||||
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
TORCH_CHECK(
|
||||
input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
|
||||
"Input must be FP16 or BF16");
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
|
||||
TORCH_CHECK(group_size == 128 || group_size == 64,
|
||||
"Unsupported group size: ", group_size);
|
||||
|
||||
if (scale_ub.has_value()) {
|
||||
TORCH_CHECK(out.dtype() == kFp8Type);
|
||||
}
|
||||
|
||||
int32_t hidden_size = out.size(-1);
|
||||
auto num_tokens = input.size(0);
|
||||
int32_t num_groups = hidden_size / group_size;
|
||||
|
||||
TORCH_CHECK(input.size(-1) == hidden_size * 2,
|
||||
"input last dim must be 2x output hidden_size");
|
||||
TORCH_CHECK(hidden_size % group_size == 0,
|
||||
"hidden_size must be divisible by group_size");
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(num_tokens, num_groups);
|
||||
dim3 block(group_size);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
|
||||
using scalar_in_t = scalar_t;
|
||||
|
||||
VLLM_DISPATCH_QUANT_TYPES(
|
||||
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
|
||||
using scalar_out_t = scalar_t;
|
||||
|
||||
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
|
||||
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
|
||||
vllm::silu_and_mul_per_block_quant_kernel<
|
||||
scalar_in_t, scalar_out_t, transpose_scale, gs>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_out_t>(),
|
||||
scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>()
|
||||
: nullptr,
|
||||
hidden_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -2,7 +2,6 @@
|
||||
#include "cuda_utils.h"
|
||||
#include "ops.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/library.h>
|
||||
#include <torch/version.h>
|
||||
|
||||
@@ -110,6 +109,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||
|
||||
// Fused SiLU+Mul + per-block quantization
|
||||
ops.def(
|
||||
"silu_and_mul_per_block_quant("
|
||||
"Tensor! out, "
|
||||
"Tensor input, "
|
||||
"Tensor! scales, "
|
||||
"int group_size, "
|
||||
"Tensor? scale_ub=None, "
|
||||
"bool is_scale_transposed=False) -> ()");
|
||||
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
|
||||
&silu_and_mul_per_block_quant);
|
||||
|
||||
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ The table below lists the quantization schemes supported by each fusion on each
|
||||
| `enable_sp` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — |
|
||||
| `fuse_gemm_comms` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — |
|
||||
| `fuse_norm_quant` | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | — | FP8 static, FP8 per-token, FP8 per-group |
|
||||
| `fuse_act_quant` | FP8 static, NVFP4 | FP8 static | FP8 static | — | FP8 per-group |
|
||||
| `fuse_act_quant` | FP8 static, NVFP4 | FP8 static, FP8 per-group (128/64) | FP8 static, FP8 per-group (128/64) | — | FP8 per-group |
|
||||
| `fuse_act_padding` | — | — | — | — | FP16/BF16 |
|
||||
|
||||
\* `fuse_attn_quant` support depends on the attention backend in use; not all backends support
|
||||
@@ -305,6 +305,7 @@ Note that AITER fusions are in a separate pass in `vllm.compilation.passes.fusio
|
||||
Supported quantization scheme/hardware combinations:
|
||||
|
||||
- FP8 static per-tensor: CUDA & HIP kernel
|
||||
- FP8 dynamic per-group (128/64): CUDA kernel (sm89+, not active when DeepGemm is used on sm100+)
|
||||
- NVFP4 dynamic: CUDA sm100+ only with FlashInfer
|
||||
- FP8 per-token-group (128): ROCm AITER only
|
||||
|
||||
@@ -313,6 +314,7 @@ Supported quantization scheme/hardware combinations:
|
||||
- Pass: [`vllm/compilation/passes/fusion/act_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/act_quant_fusion.py)
|
||||
- ROCm AITER pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py)
|
||||
- CUDA/HIP kernels: [`csrc/quantization/`](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/)
|
||||
- Fused SiLU+Mul+BlockQuant kernel: [`csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu`](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu)
|
||||
|
||||
### RMSNorm + Padding (`fuse_act_padding`)
|
||||
|
||||
|
||||
@@ -150,9 +150,8 @@ deepseek_v3_fp8 = ModelFusionInfo(
|
||||
# - post_attn_layernorm + MLP
|
||||
# 2 per MoE layer (remaining) due to MoE wrapping
|
||||
rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers
|
||||
# TODO silu+block quant
|
||||
# act_quant_fusion=min(3, n_layers), # dense layers only
|
||||
act_quant_fusion=0,
|
||||
# silu+block quant
|
||||
act_quant_fusion=min(3, n_layers), # dense layers only
|
||||
# MLA attn + quant not supported yet:
|
||||
# https://github.com/vllm-project/vllm/issues/35792
|
||||
attn_quant_fusion=0,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -34,13 +35,16 @@ from vllm.model_executor.kernels.linear import (
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
@@ -165,6 +169,48 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
||||
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
|
||||
|
||||
|
||||
class TestSiluMulBlockQuantModel(torch.nn.Module):
|
||||
quant_key = kFp8Dynamic128Sym
|
||||
|
||||
def __init__(self, hidden_size: int, is_scale_transposed: bool = False, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.is_scale_transposed = is_scale_transposed
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=False,
|
||||
group_shape=GroupShape(1, 128),
|
||||
column_major_scales=is_scale_transposed,
|
||||
compile_native=False,
|
||||
)
|
||||
|
||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||
self.enable_quant_fp8_custom_op = self.quant_fp8.enabled()
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
out, scale = self.quant_fp8(y)
|
||||
group_size = self.quant_key.scale.group_shape[1]
|
||||
scale_expanded = scale.repeat_interleave(group_size, dim=1)
|
||||
dequant = out.to(dtype=torch.float32) * scale_expanded
|
||||
return (dequant,)
|
||||
|
||||
def ops_in_model_before(self):
|
||||
ops = []
|
||||
if self.enable_silu_mul_custom_op:
|
||||
ops.append(SILU_MUL_OP)
|
||||
# When silu custom op is disabled, aten.mul.Tensor also appears
|
||||
# in dequant code, so we skip checking it to avoid false positives.
|
||||
ops.append(
|
||||
QUANT_OPS[self.quant_key]
|
||||
if self.enable_quant_fp8_custom_op
|
||||
else torch.ops.aten.reciprocal.default
|
||||
)
|
||||
return ops
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [FUSED_OPS[self.quant_key]]
|
||||
|
||||
|
||||
ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel]
|
||||
CUDA_KERNELS = [
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
@@ -200,6 +246,19 @@ TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
|
||||
not current_platform.is_rocm(), reason="ROCm only"
|
||||
),
|
||||
),
|
||||
# Block quant fusion for per-group FP8 (CUDA only).
|
||||
*[
|
||||
pytest.param(
|
||||
partial(TestSiluMulBlockQuantModel, is_scale_transposed=transposed),
|
||||
True,
|
||||
None,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="CUDA only"
|
||||
),
|
||||
id=f"TestSiluMulBlockQuant-transposed={transposed}",
|
||||
)
|
||||
for transposed in [False, True]
|
||||
],
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
@@ -213,6 +272,7 @@ def test_fusion_silu_and_mul_quant(
|
||||
TestSiluMulFp8QuantModel
|
||||
| TestSiluMulNvfp4QuantModel
|
||||
| TestSiluMulGroupFp8QuantModel
|
||||
| TestSiluMulBlockQuantModel
|
||||
],
|
||||
enable_silu_mul_custom_op: bool,
|
||||
enable_quant_fp8_custom_op: bool,
|
||||
@@ -223,6 +283,12 @@ def test_fusion_silu_and_mul_quant(
|
||||
pytest.skip("NVFP4 is not supported on this GPU.")
|
||||
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||
pytest.skip("AITER is not supported on this GPU.")
|
||||
if (
|
||||
isinstance(model_class, partial)
|
||||
and model_class.func is TestSiluMulBlockQuantModel
|
||||
and is_deep_gemm_supported()
|
||||
):
|
||||
pytest.skip("SiluMul+BlockQuant fusion not applicable with DeepGemm")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
@@ -269,11 +335,13 @@ def test_fusion_silu_and_mul_quant(
|
||||
result2 = model2(x)
|
||||
|
||||
# Check that it gives the same answer
|
||||
if model_class == TestSiluMulFp8QuantModel:
|
||||
if isinstance(model, TestSiluMulFp8QuantModel):
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||
elif isinstance(model, TestSiluMulNvfp4QuantModel):
|
||||
atol, rtol = 1e-1, 1e-1
|
||||
elif model_class == TestSiluMulGroupFp8QuantModel:
|
||||
elif isinstance(
|
||||
model, (TestSiluMulGroupFp8QuantModel, TestSiluMulBlockQuantModel)
|
||||
):
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(
|
||||
|
||||
189
tests/kernels/core/test_fused_silu_mul_block_quant.py
Normal file
189
tests/kernels/core/test_fused_silu_mul_block_quant.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
QUANT_DTYPES = [torch.float8_e4m3fn, torch.int8]
|
||||
VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
|
||||
NUM_TOKENS_HIDDEN_SIZES = [
|
||||
*[(1, i) for i in [64, *VEC_HIDDEN_SIZES, 2048, 5120]],
|
||||
*[(16, i) for i in [64, *VEC_HIDDEN_SIZES, 5120]],
|
||||
*[(128, i) for i in [64, *VEC_HIDDEN_SIZES]],
|
||||
*[(512, i) for i in [64, 5120]],
|
||||
]
|
||||
SCALE_UBS = [False]
|
||||
GROUP_SIZES = [64, 128]
|
||||
IS_SCALE_TRANSPOSED = [False, True]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
def ref_silu_and_mul_per_block_quant(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Reference implementation: unfused SiLU+Mul then group quantization."""
|
||||
hidden = x.shape[-1] // 2
|
||||
gate, up = x.split(hidden, dim=-1)
|
||||
silu_out = F.silu(gate) * up
|
||||
|
||||
if quant_dtype == current_platform.fp8_dtype():
|
||||
return per_token_group_quant_fp8(
|
||||
silu_out, group_size=group_size, use_ue8m0=False
|
||||
)
|
||||
elif quant_dtype == torch.int8:
|
||||
return per_token_group_quant_int8(silu_out, group_size=group_size)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quant_dtype: {quant_dtype}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
|
||||
@pytest.mark.parametrize("group_size", GROUP_SIZES)
|
||||
@pytest.mark.parametrize("is_scale_transposed", IS_SCALE_TRANSPOSED)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_silu_and_mul_per_block_quant(
|
||||
default_vllm_config,
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
has_scale_ub: bool,
|
||||
dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int,
|
||||
is_scale_transposed: bool,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
"""Test SiLU+Mul+Block Quantization kernel correctness."""
|
||||
torch.random.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
if hidden_size % group_size != 0:
|
||||
return
|
||||
|
||||
if has_scale_ub:
|
||||
pytest.skip("Scale upper bound not yet supported")
|
||||
|
||||
scale = 1 / hidden_size
|
||||
x = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device=device) * scale
|
||||
|
||||
# Reference implementation
|
||||
ref_out, ref_scales = ref_silu_and_mul_per_block_quant(x, quant_dtype, group_size)
|
||||
|
||||
# Fused kernel implementation
|
||||
ops_out, ops_scales = ops.silu_and_mul_per_block_quant(
|
||||
x, group_size, quant_dtype, None, is_scale_transposed
|
||||
)
|
||||
|
||||
# Check for NaN/Inf
|
||||
assert not torch.isnan(ops_out.float()).any(), "Kernel output contains NaN"
|
||||
assert not torch.isinf(ops_out.float()).any(), "Kernel output contains Inf"
|
||||
assert not torch.isnan(ops_scales).any(), "Kernel scales contain NaN"
|
||||
assert not torch.isinf(ops_scales).any(), "Kernel scales contain Inf"
|
||||
|
||||
# Check dtypes
|
||||
assert ref_out.dtype == quant_dtype
|
||||
assert ops_out.dtype == quant_dtype
|
||||
|
||||
# Check scales match
|
||||
torch.testing.assert_close(ref_scales, ops_scales, rtol=1e-5, atol=1e-5)
|
||||
|
||||
# Check output correctness via dequantized values
|
||||
ref_scales_expanded = ref_scales.repeat_interleave(group_size, dim=1)
|
||||
ops_scales_expanded = ops_scales.repeat_interleave(group_size, dim=1)
|
||||
ref_deq = ref_out.to(dtype=torch.float32) * ref_scales_expanded
|
||||
ops_deq = ops_out.to(dtype=torch.float32) * ops_scales_expanded
|
||||
torch.testing.assert_close(ref_deq, ops_deq, atol=5e-2, rtol=5e-2)
|
||||
|
||||
# opcheck
|
||||
output = torch.empty(num_tokens, hidden_size, device=device, dtype=quant_dtype)
|
||||
num_groups = hidden_size // group_size
|
||||
if is_scale_transposed:
|
||||
scales = torch.empty(num_groups, num_tokens, device=device, dtype=torch.float32)
|
||||
else:
|
||||
scales = torch.empty(num_tokens, num_groups, device=device, dtype=torch.float32)
|
||||
opcheck(
|
||||
torch.ops._C.silu_and_mul_per_block_quant,
|
||||
(output, x, scales, group_size, None, is_scale_transposed),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("hidden_size", [4096])
|
||||
@pytest.mark.parametrize("num_tokens", [128])
|
||||
@pytest.mark.parametrize("group_size", [128])
|
||||
def test_silu_block_quant_shapes(
|
||||
default_vllm_config,
|
||||
dtype: torch.dtype,
|
||||
hidden_size: int,
|
||||
num_tokens: int,
|
||||
group_size: int,
|
||||
):
|
||||
"""Test that output shapes are correct."""
|
||||
torch.set_default_device("cuda")
|
||||
x = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device="cuda")
|
||||
|
||||
# Row-major scales
|
||||
out, scales = ops.silu_and_mul_per_block_quant(
|
||||
x,
|
||||
group_size=group_size,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
is_scale_transposed=False,
|
||||
)
|
||||
assert out.shape == (num_tokens, hidden_size)
|
||||
assert scales.shape == (num_tokens, hidden_size // group_size)
|
||||
|
||||
# Column-major scales (logical shape same after .t() in _custom_ops)
|
||||
out, scales = ops.silu_and_mul_per_block_quant(
|
||||
x,
|
||||
group_size=group_size,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
is_scale_transposed=True,
|
||||
)
|
||||
assert out.shape == (num_tokens, hidden_size)
|
||||
assert scales.shape == (num_tokens, hidden_size // group_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("batch_size", [1, 16, 256])
|
||||
@pytest.mark.parametrize("hidden_size", [1024, 5120, 14336])
|
||||
def test_silu_block_quant_edge_cases(
|
||||
default_vllm_config, dtype: torch.dtype, batch_size: int, hidden_size: int
|
||||
):
|
||||
"""Test edge cases: single token, large batch, large hidden size."""
|
||||
torch.set_default_device("cuda")
|
||||
x = torch.randn(batch_size, hidden_size * 2, dtype=dtype, device="cuda")
|
||||
|
||||
out, scales = ops.silu_and_mul_per_block_quant(
|
||||
x,
|
||||
group_size=128,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
is_scale_transposed=False,
|
||||
)
|
||||
|
||||
assert out.shape == (batch_size, hidden_size)
|
||||
assert out.dtype == torch.float8_e4m3fn
|
||||
assert scales.dtype == torch.float32
|
||||
assert not torch.isnan(out.float()).any()
|
||||
assert not torch.isnan(scales).any()
|
||||
assert not torch.isinf(scales).any()
|
||||
@@ -579,6 +579,56 @@ def rms_norm_per_block_quant(
|
||||
return output, scales
|
||||
|
||||
|
||||
# fused silu_and_mul + block quant
|
||||
def silu_and_mul_per_block_quant(
|
||||
input: torch.Tensor,
|
||||
group_size: int, # Changed from list[int]
|
||||
quant_dtype: torch.dtype,
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
is_scale_transposed: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert input.ndim == 2, f"input must be 2D [batch, hidden*2], got {input.shape}"
|
||||
assert input.shape[-1] % 2 == 0, (
|
||||
f"input last dim must be even (gate||up layout), got {input.shape[-1]}"
|
||||
)
|
||||
|
||||
# Output is half the width of input (after silu_and_mul)
|
||||
num_tokens = input.shape[0]
|
||||
hidden_size = input.shape[-1] // 2 # Divide by 2 because input is [gate || up]
|
||||
|
||||
# Allocate output tensor (FP8 or INT8)
|
||||
output = torch.empty(
|
||||
(num_tokens, hidden_size), device=input.device, dtype=quant_dtype
|
||||
)
|
||||
|
||||
# Allocate scales tensor
|
||||
num_groups = hidden_size // group_size # Directly use group_size
|
||||
if is_scale_transposed:
|
||||
scales = torch.empty(
|
||||
(num_groups, num_tokens),
|
||||
device=input.device,
|
||||
dtype=torch.float32,
|
||||
).t()
|
||||
else:
|
||||
scales = torch.empty(
|
||||
(num_tokens, num_groups),
|
||||
device=input.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
# Call the C++ kernel
|
||||
torch.ops._C.silu_and_mul_per_block_quant(
|
||||
output,
|
||||
input,
|
||||
scales,
|
||||
group_size, # Pass directly as int
|
||||
scale_ub,
|
||||
is_scale_transposed,
|
||||
)
|
||||
|
||||
return output, scales
|
||||
|
||||
|
||||
# quantization ops
|
||||
# awq
|
||||
def awq_dequantize(
|
||||
|
||||
@@ -17,6 +17,8 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic64Sym,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
)
|
||||
@@ -43,6 +45,10 @@ silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
|
||||
if silu_and_mul_nvfp4_quant_supported:
|
||||
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
|
||||
|
||||
if current_platform.is_cuda():
|
||||
FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
|
||||
FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
|
||||
|
||||
|
||||
class ActivationQuantPattern(ABC):
|
||||
"""
|
||||
@@ -174,6 +180,102 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
|
||||
|
||||
|
||||
class SiluMulBlockQuantPattern(ActivationQuantPattern):
|
||||
"""
|
||||
Fusion for SiluMul+BlockQuant (FP8 dynamic per-group) Pattern.
|
||||
Supports group_size 128 and 64 via QuantKey.
|
||||
Parameterized on is_scale_transposed for different scale layouts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_key: QuantKey,
|
||||
is_scale_transposed: bool = False,
|
||||
is_e8m0: bool = False,
|
||||
is_tma_aligned: bool = False,
|
||||
) -> None:
|
||||
super().__init__(quant_key)
|
||||
self.quant_matcher = MatcherQuantFP8(
|
||||
quant_key,
|
||||
has_col_major_scales=is_scale_transposed,
|
||||
is_e8m0=is_e8m0,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
)
|
||||
self.group_size = quant_key.scale.group_shape[1]
|
||||
self.is_scale_transposed = is_scale_transposed
|
||||
self.is_e8m0 = is_e8m0
|
||||
self.is_tma_aligned = is_tma_aligned
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
scale = self.quant_matcher.empty_f32(1, 1)
|
||||
return self.silu_and_mul_matcher.inputs() + [scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
is_scale_transposed = self.is_scale_transposed
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
silu_out = self.silu_and_mul_matcher(input)
|
||||
result = torch.empty(
|
||||
silu_out.shape,
|
||||
device=silu_out.device,
|
||||
dtype=self.quant_dtype,
|
||||
)
|
||||
assert scale is not None
|
||||
finfo = torch.finfo(self.quant_dtype)
|
||||
_, result, scale = auto_functionalized(
|
||||
self.quant_matcher.QUANT_OP,
|
||||
input=silu_out,
|
||||
output_q=result,
|
||||
output_s=scale,
|
||||
group_size=self.group_size,
|
||||
eps=1e-10,
|
||||
fp8_min=finfo.min,
|
||||
fp8_max=finfo.max,
|
||||
scale_ue8m0=self.is_e8m0,
|
||||
dummy_is_scale_transposed=is_scale_transposed,
|
||||
dummy_is_tma_aligned=self.is_tma_aligned,
|
||||
)
|
||||
return result, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
d = input.shape[-1] // 2
|
||||
output_shape = input.shape[:-1] + (d,)
|
||||
result = torch.empty(
|
||||
output_shape, device=input.device, dtype=self.quant_dtype
|
||||
)
|
||||
if is_scale_transposed:
|
||||
scale = torch.empty(
|
||||
(d // self.group_size, input.shape[0]),
|
||||
device=input.device,
|
||||
dtype=torch.float32,
|
||||
).permute(-1, -2)
|
||||
else:
|
||||
scale = torch.empty(
|
||||
(input.shape[0], d // self.group_size),
|
||||
device=input.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
out=result,
|
||||
input=input,
|
||||
scales=scale,
|
||||
group_size=self.group_size,
|
||||
scale_ub=None,
|
||||
is_scale_transposed=is_scale_transposed,
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
inps = self.get_inputs()
|
||||
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
|
||||
|
||||
|
||||
class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
@@ -199,6 +301,18 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
|
||||
pattern_silu_mul_nvfp4.register(self.patterns)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
for quant_key in [kFp8Dynamic128Sym, kFp8Dynamic64Sym]:
|
||||
for is_scale_transposed in [False, True]:
|
||||
for is_e8m0 in [True, False]:
|
||||
for is_tma_aligned in [False, True]:
|
||||
SiluMulBlockQuantPattern(
|
||||
quant_key,
|
||||
is_scale_transposed=is_scale_transposed,
|
||||
is_e8m0=is_e8m0,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
@@ -212,4 +326,5 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
ActivationQuantPattern,
|
||||
SiluMulFp8StaticQuantPattern,
|
||||
SiluMulNvfp4QuantPattern,
|
||||
SiluMulBlockQuantPattern,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user