Feature/silu block quant fusion v1 (#32996)

Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
This commit is contained in:
Monishver
2026-04-01 11:50:43 -07:00
committed by GitHub
parent c9a9db0e02
commit c09ad767cd
11 changed files with 830 additions and 9 deletions

View File

@@ -340,7 +340,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/cutlass_extensions/common.cpp")
"csrc/cutlass_extensions/common.cpp"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"

View File

@@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from itertools import product
import torch
import torch.nn.functional as F
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm
import vllm._custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
@dataclass
class bench_params_t:
num_tokens: int
hidden_size: int
dtype: torch.dtype
group_size: int # Changed from list[int] to int
def description(self):
return (
f"N {self.num_tokens} "
f"x D {self.hidden_size} "
f"x DT {self.dtype} "
f"x GS {self.group_size}"
)
def get_bench_params() -> list[bench_params_t]:
"""Test configurations covering common model sizes."""
NUM_TOKENS = [16, 128, 512, 2048]
HIDDEN_SIZES = [1024, 2048, 4096, 5120, 14336] # Common FFN sizes
DTYPES = [torch.float16, torch.bfloat16]
GROUP_SIZES = [64, 128] # Changed from [[1, 64], [1, 128]]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, DTYPES, GROUP_SIZES)
bench_params = list(
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
)
return bench_params
# Reference implementations
def unfused_fp8_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int, # Changed from list[int]
):
"""Unfused: SiLU+Mul then per-tensor quantize."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
# SiLU(gate) * up
silu_out = F.silu(gate) * up
# Per-tensor quantize (no group_size used here)
silu_out, _ = ops.scaled_fp8_quant(silu_out)
def unfused_groupwise_fp8_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int, # Changed from list[int]
):
"""Unfused: SiLU+Mul then group-wise quantize."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
# SiLU(gate) * up
silu_out = F.silu(gate) * up
# Group quantize - use group_size directly
silu_out, _ = per_token_group_quant_fp8(
silu_out, group_size=group_size, use_ue8m0=False
)
def fused_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
):
"""Fused: SiLU+Mul+Block Quantization in single kernel."""
out, _ = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=quant_dtype,
is_scale_transposed=False,
)
# Bench functions
def bench_fn(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
label: str,
sub_label: str,
fn: Callable,
description: str,
) -> TMeasurement:
min_run_time = 1
globals = {
"x": x,
"quant_dtype": quant_dtype,
"group_size": group_size,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(x, quant_dtype, group_size)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
"""Run benchmarks for all implementations."""
# Make inputs: [num_tokens, hidden_size * 2] for [gate || up]
scale = 1 / params.hidden_size
x = (
torch.randn(
params.num_tokens,
params.hidden_size * 2,
dtype=params.dtype,
device="cuda",
)
* scale
)
timers = []
# Unfused per-tensor FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_fp8_impl,
"unfused_fp8_impl",
)
)
# Unfused group-wise FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_groupwise_fp8_impl,
"unfused_groupwise_fp8_impl",
)
)
# Fused group-wise FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_impl,
"fused_groupwise_fp8_impl",
)
)
return timers
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def main():
torch.set_default_device("cuda")
bench_params = get_bench_params()
print(f"Running {len(bench_params)} benchmark configurations...")
print(
f"This will take approximately {len(bench_params) * 3} seconds (1s per variant)"
)
print()
timers = []
for bp in tqdm(bench_params):
result_timers = bench(bp, "silu-mul-block-quant", bp.description())
timers.extend(result_timers)
print("\n" + "=" * 80)
print("FINAL COMPARISON - ALL RESULTS")
print("=" * 80)
print_timers(timers)
if __name__ == "__main__":
main()

View File

@@ -142,6 +142,12 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
std::optional<torch::Tensor> residual,
int64_t group_size, bool is_scale_transposed);
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);

View File

@@ -0,0 +1,169 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../dispatch_utils.h"
#include "quant_conversions.cuh"
#include "../w8a8/fp8/common.cuh"
namespace vllm {
// Logic: one thread block per (token, group) pair
template <typename scalar_t, typename scalar_out_t, bool is_scale_transposed,
int32_t group_size>
__global__ void silu_and_mul_per_block_quant_kernel(
scalar_out_t* __restrict__ out, // Output: [num_tokens, hidden_size] in
// FP8/INT8
float* __restrict__ scales, // Output: [num_tokens, hidden_size /
// group_size] or [hidden_size / group_size,
// num_tokens]
scalar_t const* __restrict__ input, // Input: [num_tokens, hidden_size * 2]
float const* scale_ub, // Optional scale upper bound
int32_t const hidden_size // Output hidden size (input is 2x this)
) {
static_assert((group_size & (group_size - 1)) == 0,
"group_size must be a power of 2 for correct reduction");
// Grid: (num_tokens, num_groups)
int const token_idx = blockIdx.x;
int const group_idx = blockIdx.y;
int const tid = threadIdx.x; // tid in [0, group_size)
int const num_tokens = gridDim.x;
// Input layout: [gate || up] concatenated along last dimension
int const input_stride = hidden_size * 2;
int const group_start = group_idx * group_size;
// Pointers to this token's data
scalar_t const* token_input_gate =
input + token_idx * input_stride + group_start;
scalar_t const* token_input_up = token_input_gate + hidden_size;
scalar_out_t* token_output = out + token_idx * hidden_size + group_start;
// Scale pointer for this group
int const num_groups = gridDim.y;
float* group_scale_ptr = is_scale_transposed
? scales + group_idx * num_tokens + token_idx
: scales + token_idx * num_groups + group_idx;
// Shared memory for reduction (compile-time sized)
__shared__ float shared_max[group_size];
// Step 1: Each thread loads one element, computes SiLU, stores in register
float gate = static_cast<float>(token_input_gate[tid]);
float up = static_cast<float>(token_input_up[tid]);
// Compute SiLU(gate) * up
float sigmoid_gate = 1.0f / (1.0f + expf(-gate));
float silu_gate = gate * sigmoid_gate;
float result = silu_gate * up; // Keep in register
// Step 2: Reduce to find group max
shared_max[tid] = fabsf(result);
__syncthreads();
// Power-of-2 reduction (group_size guaranteed to be power of 2)
#pragma unroll
for (int stride = group_size / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]);
}
__syncthreads();
}
// Step 3: Compute scale (thread 0), broadcast via shared memory
if (tid == 0) {
float group_max = shared_max[0];
float const quant_range = quant_type_max_v<scalar_out_t>;
float group_scale = group_max / quant_range;
// Apply scale upper bound if provided
if (scale_ub != nullptr) {
group_scale = fminf(group_scale, *scale_ub);
}
// Use minimum safe scaling factor
group_scale = fmaxf(group_scale, min_scaling_factor<scalar_out_t>::val());
// Store scale to global memory
*group_scale_ptr = group_scale;
// Reuse shared_max[0] to broadcast scale
shared_max[0] = group_scale;
}
__syncthreads();
float group_scale = shared_max[0];
// Step 4: Quantize and write output
token_output[tid] =
vllm::ScaledQuant<scalar_out_t, false>::quant_fn(result, group_scale);
}
} // namespace vllm
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
TORCH_CHECK(
input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
"Input must be FP16 or BF16");
TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
TORCH_CHECK(group_size == 128 || group_size == 64,
"Unsupported group size: ", group_size);
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
int32_t hidden_size = out.size(-1);
auto num_tokens = input.size(0);
int32_t num_groups = hidden_size / group_size;
TORCH_CHECK(input.size(-1) == hidden_size * 2,
"input last dim must be 2x output hidden_size");
TORCH_CHECK(hidden_size % group_size == 0,
"hidden_size must be divisible by group_size");
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(num_tokens, num_groups);
dim3 block(group_size);
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_in_t = scalar_t;
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_out_t = scalar_t;
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
vllm::silu_and_mul_per_block_quant_kernel<
scalar_in_t, scalar_out_t, transpose_scale, gs>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_out_t>(),
scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr,
hidden_size);
});
});
});
});
}

View File

@@ -2,7 +2,6 @@
#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"
#include <torch/library.h>
#include <torch/version.h>
@@ -110,6 +109,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
// Fused SiLU+Mul + per-block quantization
ops.def(
"silu_and_mul_per_block_quant("
"Tensor! out, "
"Tensor input, "
"Tensor! scales, "
"int group_size, "
"Tensor? scale_ub=None, "
"bool is_scale_transposed=False) -> ()");
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
&silu_and_mul_per_block_quant);
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);

View File

@@ -45,7 +45,7 @@ The table below lists the quantization schemes supported by each fusion on each
| `enable_sp` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — |
| `fuse_gemm_comms` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — |
| `fuse_norm_quant` | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | — | FP8 static, FP8 per-token, FP8 per-group |
| `fuse_act_quant` | FP8 static, NVFP4 | FP8 static | FP8 static | — | FP8 per-group |
| `fuse_act_quant` | FP8 static, NVFP4 | FP8 static, FP8 per-group (128/64) | FP8 static, FP8 per-group (128/64) | — | FP8 per-group |
| `fuse_act_padding` | — | — | — | — | FP16/BF16 |
\* `fuse_attn_quant` support depends on the attention backend in use; not all backends support
@@ -305,6 +305,7 @@ Note that AITER fusions are in a separate pass in `vllm.compilation.passes.fusio
Supported quantization scheme/hardware combinations:
- FP8 static per-tensor: CUDA & HIP kernel
- FP8 dynamic per-group (128/64): CUDA kernel (sm89+, not active when DeepGemm is used on sm100+)
- NVFP4 dynamic: CUDA sm100+ only with FlashInfer
- FP8 per-token-group (128): ROCm AITER only
@@ -313,6 +314,7 @@ Supported quantization scheme/hardware combinations:
- Pass: [`vllm/compilation/passes/fusion/act_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/act_quant_fusion.py)
- ROCm AITER pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py)
- CUDA/HIP kernels: [`csrc/quantization/`](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/)
- Fused SiLU+Mul+BlockQuant kernel: [`csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu`](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu)
### RMSNorm + Padding (`fuse_act_padding`)

View File

@@ -150,9 +150,8 @@ deepseek_v3_fp8 = ModelFusionInfo(
# - post_attn_layernorm + MLP
# 2 per MoE layer (remaining) due to MoE wrapping
rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers
# TODO silu+block quant
# act_quant_fusion=min(3, n_layers), # dense layers only
act_quant_fusion=0,
# silu+block quant
act_quant_fusion=min(3, n_layers), # dense layers only
# MLA attn + quant not supported yet:
# https://github.com/vllm-project/vllm/issues/35792
attn_quant_fusion=0,

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from functools import partial
import pytest
import torch
@@ -34,13 +35,16 @@ from vllm.model_executor.kernels.linear import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8Dynamic128Sym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
@@ -165,6 +169,48 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
class TestSiluMulBlockQuantModel(torch.nn.Module):
quant_key = kFp8Dynamic128Sym
def __init__(self, hidden_size: int, is_scale_transposed: bool = False, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.is_scale_transposed = is_scale_transposed
self.quant_fp8 = QuantFP8(
static=False,
group_shape=GroupShape(1, 128),
column_major_scales=is_scale_transposed,
compile_native=False,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.quant_fp8.enabled()
def forward(self, x):
y = self.silu_and_mul(x)
out, scale = self.quant_fp8(y)
group_size = self.quant_key.scale.group_shape[1]
scale_expanded = scale.repeat_interleave(group_size, dim=1)
dequant = out.to(dtype=torch.float32) * scale_expanded
return (dequant,)
def ops_in_model_before(self):
ops = []
if self.enable_silu_mul_custom_op:
ops.append(SILU_MUL_OP)
# When silu custom op is disabled, aten.mul.Tensor also appears
# in dequant code, so we skip checking it to avoid false positives.
ops.append(
QUANT_OPS[self.quant_key]
if self.enable_quant_fp8_custom_op
else torch.ops.aten.reciprocal.default
)
return ops
def ops_in_model_after(self):
return [FUSED_OPS[self.quant_key]]
ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel]
CUDA_KERNELS = [
FlashInferFP8ScaledMMLinearKernel,
@@ -200,6 +246,19 @@ TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
not current_platform.is_rocm(), reason="ROCm only"
),
),
# Block quant fusion for per-group FP8 (CUDA only).
*[
pytest.param(
partial(TestSiluMulBlockQuantModel, is_scale_transposed=transposed),
True,
None,
marks=pytest.mark.skipif(
not current_platform.is_cuda(), reason="CUDA only"
),
id=f"TestSiluMulBlockQuant-transposed={transposed}",
)
for transposed in [False, True]
],
],
)
@pytest.mark.skipif(
@@ -213,6 +272,7 @@ def test_fusion_silu_and_mul_quant(
TestSiluMulFp8QuantModel
| TestSiluMulNvfp4QuantModel
| TestSiluMulGroupFp8QuantModel
| TestSiluMulBlockQuantModel
],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
@@ -223,6 +283,12 @@ def test_fusion_silu_and_mul_quant(
pytest.skip("NVFP4 is not supported on this GPU.")
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
if (
isinstance(model_class, partial)
and model_class.func is TestSiluMulBlockQuantModel
and is_deep_gemm_supported()
):
pytest.skip("SiluMul+BlockQuant fusion not applicable with DeepGemm")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
@@ -269,11 +335,13 @@ def test_fusion_silu_and_mul_quant(
result2 = model2(x)
# Check that it gives the same answer
if model_class == TestSiluMulFp8QuantModel:
if isinstance(model, TestSiluMulFp8QuantModel):
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
elif isinstance(model, TestSiluMulNvfp4QuantModel):
atol, rtol = 1e-1, 1e-1
elif model_class == TestSiluMulGroupFp8QuantModel:
elif isinstance(
model, (TestSiluMulGroupFp8QuantModel, TestSiluMulBlockQuantModel)
):
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8,
)
from vllm.platforms import current_platform
DTYPES = [torch.float16, torch.bfloat16]
QUANT_DTYPES = [torch.float8_e4m3fn, torch.int8]
VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
NUM_TOKENS_HIDDEN_SIZES = [
*[(1, i) for i in [64, *VEC_HIDDEN_SIZES, 2048, 5120]],
*[(16, i) for i in [64, *VEC_HIDDEN_SIZES, 5120]],
*[(128, i) for i in [64, *VEC_HIDDEN_SIZES]],
*[(512, i) for i in [64, 5120]],
]
SCALE_UBS = [False]
GROUP_SIZES = [64, 128]
IS_SCALE_TRANSPOSED = [False, True]
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
def ref_silu_and_mul_per_block_quant(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Reference implementation: unfused SiLU+Mul then group quantization."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
silu_out = F.silu(gate) * up
if quant_dtype == current_platform.fp8_dtype():
return per_token_group_quant_fp8(
silu_out, group_size=group_size, use_ue8m0=False
)
elif quant_dtype == torch.int8:
return per_token_group_quant_int8(silu_out, group_size=group_size)
else:
raise ValueError(f"Unsupported quant_dtype: {quant_dtype}")
@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("group_size", GROUP_SIZES)
@pytest.mark.parametrize("is_scale_transposed", IS_SCALE_TRANSPOSED)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul_per_block_quant(
default_vllm_config,
num_tokens: int,
hidden_size: int,
has_scale_ub: bool,
dtype: torch.dtype,
quant_dtype: torch.dtype,
group_size: int,
is_scale_transposed: bool,
seed: int,
device: str,
) -> None:
"""Test SiLU+Mul+Block Quantization kernel correctness."""
torch.random.manual_seed(seed)
torch.set_default_device(device)
if hidden_size % group_size != 0:
return
if has_scale_ub:
pytest.skip("Scale upper bound not yet supported")
scale = 1 / hidden_size
x = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device=device) * scale
# Reference implementation
ref_out, ref_scales = ref_silu_and_mul_per_block_quant(x, quant_dtype, group_size)
# Fused kernel implementation
ops_out, ops_scales = ops.silu_and_mul_per_block_quant(
x, group_size, quant_dtype, None, is_scale_transposed
)
# Check for NaN/Inf
assert not torch.isnan(ops_out.float()).any(), "Kernel output contains NaN"
assert not torch.isinf(ops_out.float()).any(), "Kernel output contains Inf"
assert not torch.isnan(ops_scales).any(), "Kernel scales contain NaN"
assert not torch.isinf(ops_scales).any(), "Kernel scales contain Inf"
# Check dtypes
assert ref_out.dtype == quant_dtype
assert ops_out.dtype == quant_dtype
# Check scales match
torch.testing.assert_close(ref_scales, ops_scales, rtol=1e-5, atol=1e-5)
# Check output correctness via dequantized values
ref_scales_expanded = ref_scales.repeat_interleave(group_size, dim=1)
ops_scales_expanded = ops_scales.repeat_interleave(group_size, dim=1)
ref_deq = ref_out.to(dtype=torch.float32) * ref_scales_expanded
ops_deq = ops_out.to(dtype=torch.float32) * ops_scales_expanded
torch.testing.assert_close(ref_deq, ops_deq, atol=5e-2, rtol=5e-2)
# opcheck
output = torch.empty(num_tokens, hidden_size, device=device, dtype=quant_dtype)
num_groups = hidden_size // group_size
if is_scale_transposed:
scales = torch.empty(num_groups, num_tokens, device=device, dtype=torch.float32)
else:
scales = torch.empty(num_tokens, num_groups, device=device, dtype=torch.float32)
opcheck(
torch.ops._C.silu_and_mul_per_block_quant,
(output, x, scales, group_size, None, is_scale_transposed),
)
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("num_tokens", [128])
@pytest.mark.parametrize("group_size", [128])
def test_silu_block_quant_shapes(
default_vllm_config,
dtype: torch.dtype,
hidden_size: int,
num_tokens: int,
group_size: int,
):
"""Test that output shapes are correct."""
torch.set_default_device("cuda")
x = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device="cuda")
# Row-major scales
out, scales = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=torch.float8_e4m3fn,
is_scale_transposed=False,
)
assert out.shape == (num_tokens, hidden_size)
assert scales.shape == (num_tokens, hidden_size // group_size)
# Column-major scales (logical shape same after .t() in _custom_ops)
out, scales = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=torch.float8_e4m3fn,
is_scale_transposed=True,
)
assert out.shape == (num_tokens, hidden_size)
assert scales.shape == (num_tokens, hidden_size // group_size)
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("batch_size", [1, 16, 256])
@pytest.mark.parametrize("hidden_size", [1024, 5120, 14336])
def test_silu_block_quant_edge_cases(
default_vllm_config, dtype: torch.dtype, batch_size: int, hidden_size: int
):
"""Test edge cases: single token, large batch, large hidden size."""
torch.set_default_device("cuda")
x = torch.randn(batch_size, hidden_size * 2, dtype=dtype, device="cuda")
out, scales = ops.silu_and_mul_per_block_quant(
x,
group_size=128,
quant_dtype=torch.float8_e4m3fn,
is_scale_transposed=False,
)
assert out.shape == (batch_size, hidden_size)
assert out.dtype == torch.float8_e4m3fn
assert scales.dtype == torch.float32
assert not torch.isnan(out.float()).any()
assert not torch.isnan(scales).any()
assert not torch.isinf(scales).any()

View File

@@ -579,6 +579,56 @@ def rms_norm_per_block_quant(
return output, scales
# fused silu_and_mul + block quant
def silu_and_mul_per_block_quant(
input: torch.Tensor,
group_size: int, # Changed from list[int]
quant_dtype: torch.dtype,
scale_ub: torch.Tensor | None = None,
is_scale_transposed: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert input.ndim == 2, f"input must be 2D [batch, hidden*2], got {input.shape}"
assert input.shape[-1] % 2 == 0, (
f"input last dim must be even (gate||up layout), got {input.shape[-1]}"
)
# Output is half the width of input (after silu_and_mul)
num_tokens = input.shape[0]
hidden_size = input.shape[-1] // 2 # Divide by 2 because input is [gate || up]
# Allocate output tensor (FP8 or INT8)
output = torch.empty(
(num_tokens, hidden_size), device=input.device, dtype=quant_dtype
)
# Allocate scales tensor
num_groups = hidden_size // group_size # Directly use group_size
if is_scale_transposed:
scales = torch.empty(
(num_groups, num_tokens),
device=input.device,
dtype=torch.float32,
).t()
else:
scales = torch.empty(
(num_tokens, num_groups),
device=input.device,
dtype=torch.float32,
)
# Call the C++ kernel
torch.ops._C.silu_and_mul_per_block_quant(
output,
input,
scales,
group_size, # Pass directly as int
scale_ub,
is_scale_transposed,
)
return output, scales
# quantization ops
# awq
def awq_dequantize(

View File

@@ -17,6 +17,8 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
@@ -43,6 +45,10 @@ silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
if silu_and_mul_nvfp4_quant_supported:
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
if current_platform.is_cuda():
FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
class ActivationQuantPattern(ABC):
"""
@@ -174,6 +180,102 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
class SiluMulBlockQuantPattern(ActivationQuantPattern):
"""
Fusion for SiluMul+BlockQuant (FP8 dynamic per-group) Pattern.
Supports group_size 128 and 64 via QuantKey.
Parameterized on is_scale_transposed for different scale layouts.
"""
def __init__(
self,
quant_key: QuantKey,
is_scale_transposed: bool = False,
is_e8m0: bool = False,
is_tma_aligned: bool = False,
) -> None:
super().__init__(quant_key)
self.quant_matcher = MatcherQuantFP8(
quant_key,
has_col_major_scales=is_scale_transposed,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
)
self.group_size = quant_key.scale.group_shape[1]
self.is_scale_transposed = is_scale_transposed
self.is_e8m0 = is_e8m0
self.is_tma_aligned = is_tma_aligned
def get_inputs(self) -> list[torch.Tensor]:
scale = self.quant_matcher.empty_f32(1, 1)
return self.silu_and_mul_matcher.inputs() + [scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
is_scale_transposed = self.is_scale_transposed
def pattern(
input: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
silu_out = self.silu_and_mul_matcher(input)
result = torch.empty(
silu_out.shape,
device=silu_out.device,
dtype=self.quant_dtype,
)
assert scale is not None
finfo = torch.finfo(self.quant_dtype)
_, result, scale = auto_functionalized(
self.quant_matcher.QUANT_OP,
input=silu_out,
output_q=result,
output_s=scale,
group_size=self.group_size,
eps=1e-10,
fp8_min=finfo.min,
fp8_max=finfo.max,
scale_ue8m0=self.is_e8m0,
dummy_is_scale_transposed=is_scale_transposed,
dummy_is_tma_aligned=self.is_tma_aligned,
)
return result, scale
def replacement(
input: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
d = input.shape[-1] // 2
output_shape = input.shape[:-1] + (d,)
result = torch.empty(
output_shape, device=input.device, dtype=self.quant_dtype
)
if is_scale_transposed:
scale = torch.empty(
(d // self.group_size, input.shape[0]),
device=input.device,
dtype=torch.float32,
).permute(-1, -2)
else:
scale = torch.empty(
(input.shape[0], d // self.group_size),
device=input.device,
dtype=torch.float32,
)
at = auto_functionalized(
self.FUSED_OP,
out=result,
input=input,
scales=scale,
group_size=self.group_size,
scale_ub=None,
is_scale_transposed=is_scale_transposed,
)
return at[1], at[2]
inps = self.get_inputs()
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
@@ -199,6 +301,18 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)
if current_platform.is_cuda():
for quant_key in [kFp8Dynamic128Sym, kFp8Dynamic64Sym]:
for is_scale_transposed in [False, True]:
for is_e8m0 in [True, False]:
for is_tma_aligned in [False, True]:
SiluMulBlockQuantPattern(
quant_key,
is_scale_transposed=is_scale_transposed,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
@@ -212,4 +326,5 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
ActivationQuantPattern,
SiluMulFp8StaticQuantPattern,
SiluMulNvfp4QuantPattern,
SiluMulBlockQuantPattern,
)