[ROCm] Enable fused_silu_mul_block_quant on ROCm (#38817)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
d74a306c4b
commit
56c976c1b5
@@ -299,6 +299,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/w8a8/int8/scaled_quant.cu"
|
||||
"csrc/quantization/w8a8/fp8/common.cu"
|
||||
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
@@ -340,8 +341,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu")
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
|
||||
@@ -143,13 +143,11 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
std::optional<torch::Tensor> residual,
|
||||
int64_t group_size, bool is_scale_transposed);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
bool is_scale_transposed);
|
||||
#endif
|
||||
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
#include "libtorch_stable/quantization/vectorization.cuh"
|
||||
// TODO(luka/varun):refactor common.cuh to use this file instead
|
||||
#include "quantization/w8a8/fp8/common.cuh"
|
||||
#include "../w8a8/fp8/common.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "libtorch_stable/quantization/vectorization.cuh"
|
||||
#include "quantization/utils.cuh"
|
||||
#include "../../utils.cuh"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
|
||||
@@ -110,6 +110,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||
|
||||
// Fused SiLU+Mul + per-block quantization
|
||||
ops.def(
|
||||
"silu_and_mul_per_block_quant("
|
||||
"Tensor! out, "
|
||||
"Tensor input, "
|
||||
"Tensor! scales, "
|
||||
"int group_size, "
|
||||
"Tensor? scale_ub=None, "
|
||||
"bool is_scale_transposed=False) -> ()");
|
||||
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
|
||||
&silu_and_mul_per_block_quant);
|
||||
|
||||
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
||||
|
||||
@@ -233,17 +245,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
// Fused SiLU+Mul + per-block quantization
|
||||
ops.def(
|
||||
"silu_and_mul_per_block_quant("
|
||||
"Tensor! out, "
|
||||
"Tensor input, "
|
||||
"Tensor! scales, "
|
||||
"int group_size, "
|
||||
"Tensor? scale_ub=None, "
|
||||
"bool is_scale_transposed=False) -> ()");
|
||||
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
|
||||
&silu_and_mul_per_block_quant);
|
||||
// DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
|
||||
ops.def(
|
||||
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
|
||||
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
QUANT_DTYPES = [torch.float8_e4m3fn, torch.int8]
|
||||
QUANT_DTYPES = [current_platform.fp8_dtype(), torch.int8]
|
||||
VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
|
||||
NUM_TOKENS_HIDDEN_SIZES = [
|
||||
*[(1, i) for i in [64, *VEC_HIDDEN_SIZES, 2048, 5120]],
|
||||
@@ -28,9 +28,7 @@ SCALE_UBS = [False]
|
||||
GROUP_SIZES = [64, 128]
|
||||
IS_SCALE_TRANSPOSED = [False, True]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
|
||||
]
|
||||
CUDA_DEVICES = [i for i in range(1 if torch.accelerator.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
def ref_silu_and_mul_per_block_quant(
|
||||
@@ -60,7 +58,7 @@ def ref_silu_and_mul_per_block_quant(
|
||||
@pytest.mark.parametrize("group_size", GROUP_SIZES)
|
||||
@pytest.mark.parametrize("is_scale_transposed", IS_SCALE_TRANSPOSED)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device_idx", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_silu_and_mul_per_block_quant(
|
||||
default_vllm_config,
|
||||
@@ -72,9 +70,11 @@ def test_silu_and_mul_per_block_quant(
|
||||
group_size: int,
|
||||
is_scale_transposed: bool,
|
||||
seed: int,
|
||||
device: str,
|
||||
device_idx: str,
|
||||
) -> None:
|
||||
"""Test SiLU+Mul+Block Quantization kernel correctness."""
|
||||
torch.accelerator.set_device_index(device_idx)
|
||||
device = f"cuda:{device_idx}"
|
||||
torch.random.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
@@ -147,7 +147,7 @@ def test_silu_block_quant_shapes(
|
||||
out, scales = ops.silu_and_mul_per_block_quant(
|
||||
x,
|
||||
group_size=group_size,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
quant_dtype=current_platform.fp8_dtype(),
|
||||
is_scale_transposed=False,
|
||||
)
|
||||
assert out.shape == (num_tokens, hidden_size)
|
||||
@@ -157,7 +157,7 @@ def test_silu_block_quant_shapes(
|
||||
out, scales = ops.silu_and_mul_per_block_quant(
|
||||
x,
|
||||
group_size=group_size,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
quant_dtype=current_platform.fp8_dtype(),
|
||||
is_scale_transposed=True,
|
||||
)
|
||||
assert out.shape == (num_tokens, hidden_size)
|
||||
@@ -177,12 +177,12 @@ def test_silu_block_quant_edge_cases(
|
||||
out, scales = ops.silu_and_mul_per_block_quant(
|
||||
x,
|
||||
group_size=128,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
quant_dtype=current_platform.fp8_dtype(),
|
||||
is_scale_transposed=False,
|
||||
)
|
||||
|
||||
assert out.shape == (batch_size, hidden_size)
|
||||
assert out.dtype == torch.float8_e4m3fn
|
||||
assert out.dtype == current_platform.fp8_dtype()
|
||||
assert scales.dtype == torch.float32
|
||||
assert not torch.isnan(out.float()).any()
|
||||
assert not torch.isnan(scales).any()
|
||||
|
||||
@@ -45,7 +45,7 @@ silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
|
||||
if silu_and_mul_nvfp4_quant_supported:
|
||||
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
|
||||
|
||||
if current_platform.is_cuda():
|
||||
if current_platform.is_cuda_alike():
|
||||
FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
|
||||
FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
|
||||
|
||||
@@ -301,7 +301,7 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
|
||||
pattern_silu_mul_nvfp4.register(self.patterns)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
if current_platform.is_cuda_alike():
|
||||
for quant_key in [kFp8Dynamic128Sym, kFp8Dynamic64Sym]:
|
||||
for is_scale_transposed in [False, True]:
|
||||
for is_e8m0 in [True, False]:
|
||||
|
||||
Reference in New Issue
Block a user