[Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
This commit is contained in:
Roberto L. Castro
2026-01-25 02:45:27 +01:00
committed by GitHub
parent 1ebdff412a
commit fcb9df99bd
18 changed files with 508 additions and 151 deletions

View File

@@ -20,8 +20,12 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
PROVIDER_CFGS = {
"vllm": dict(backend="vllm", enabled=True),
"flashinfer": dict(backend="flashinfer", enabled=True),
"vllm": dict(backend="vllm", is_sf_swizzled_layout=False, enabled=True),
"vllm-swizzle": dict(backend="vllm", is_sf_swizzled_layout=True, enabled=True),
"flashinfer": dict(backend="flashinfer", is_sf_swizzled_layout=False, enabled=True),
"flashinfer-swizzle": dict(
backend="flashinfer", is_sf_swizzled_layout=True, enabled=True
),
}
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
@@ -36,7 +40,7 @@ def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor:
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
x_log=False,
line_arg="provider",
line_vals=_enabled,
@@ -63,19 +67,36 @@ def benchmark(batch_size, provider, N, K):
if cfg["backend"] == "vllm":
# vLLM's FP4 quantization
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: ops.scaled_fp4_quant(a, a_global_scale),
quantiles=quantiles,
)
if cfg["is_sf_swizzled_layout"]:
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: ops.scaled_fp4_quant(
a, a_global_scale, is_sf_swizzled_layout=True
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: ops.scaled_fp4_quant(
a, a_global_scale, is_sf_swizzled_layout=False
),
quantiles=quantiles,
)
elif cfg["backend"] == "flashinfer":
# FlashInfer's FP4 quantization
# Use is_sf_swizzled_layout=True to match vLLM's output format
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: flashinfer_fp4_quantize(
a, a_global_scale, is_sf_swizzled_layout=True
),
quantiles=quantiles,
)
if cfg["is_sf_swizzled_layout"]:
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: flashinfer_fp4_quantize(
a, a_global_scale, is_sf_swizzled_layout=True
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: flashinfer_fp4_quantize(
a, a_global_scale, is_sf_swizzled_layout=False
),
quantiles=quantiles,
)
# Convert ms to us for better readability at small batch sizes
to_us = lambda t_ms: t_ms * 1000
@@ -92,7 +113,9 @@ def prepare_shapes(args):
return out
def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
def _test_accuracy_once(
M: int, K: int, dtype: torch.dtype, device: str, is_sf_swizzled_layout: bool
):
"""Test accuracy between vLLM and FlashInfer FP4 quantization."""
# Create input tensor
a = torch.randn((M, K), device=device, dtype=dtype)
@@ -101,11 +124,13 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
a_global_scale = compute_global_scale(a)
# vLLM quantization
vllm_fp4, vllm_scale = ops.scaled_fp4_quant(a, a_global_scale)
vllm_fp4, vllm_scale = ops.scaled_fp4_quant(
a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
)
# FlashInfer quantization (with swizzled layout to match vLLM's output)
flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize(
a, a_global_scale, is_sf_swizzled_layout=True
a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
)
flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn)
@@ -114,7 +139,14 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
vllm_fp4,
flashinfer_fp4,
)
print(f"M={M}, K={K}, dtype={dtype}: PASSED")
# Compare scales
torch.testing.assert_close(
vllm_scale,
flashinfer_scale,
)
print(
f"M={M}, K={K}, dtype={dtype}, is_sf_swizzled_layout={is_sf_swizzled_layout}: PASSED" # noqa: E501
)
def test_accuracy():
@@ -130,9 +162,10 @@ def test_accuracy():
Ms = [1, 1024]
Ks = [4096]
for M in Ms:
for K in Ks:
_test_accuracy_once(M, K, dtype, device)
for is_sf_swizzled_layout in [True, False]:
for M in Ms:
for K in Ks:
_test_accuracy_once(M, K, dtype, device, is_sf_swizzled_layout)
print("\nAll accuracy tests passed!")
@@ -145,7 +178,7 @@ if __name__ == "__main__":
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
default=["meta-llama/Llama-3.3-70B-Instruct"],
choices=list(WEIGHT_SHAPES.keys()),
)
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])

View File

@@ -293,7 +293,8 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_scale,
torch::Tensor const& input_scale);
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,

View File

@@ -27,17 +27,24 @@
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
// Define before including nvfp4_utils.cuh so the header
// can use this macro during compilation.
#define NVFP4_ENABLE_ELTS16 1
#include "nvfp4_utils.cuh"
namespace vllm {
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
float const* SFScale, uint32_t* out,
uint32_t* SFout) {
using PackedVec = PackedVec<Type>;
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols,
int32_t num_padded_cols,
Type const* __restrict__ in,
float const* __restrict__ SFScale,
uint32_t* __restrict__ out,
uint32_t* __restrict__ SFout) {
using PackedVec = vllm::PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
@@ -49,34 +56,60 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
float const SFScaleVal = (SFScale == nullptr) ? 1.0f : SFScale[0];
int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;
int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
colIdx += blockDim.x) {
if (colIdx < num_padded_cols) {
PackedVec in_vec;
PackedVec in_vec2;
int64_t inOffset =
rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx;
int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) +
numCols / CVT_FP4_ELTS_PER_THREAD + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
PackedVec in_vec2 = reinterpret_cast<PackedVec const*>(in)[inOffset2];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
auto& out_pos = out[outOffset];
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
if constexpr (CVT_FP4_PACK16) {
ld256_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid);
ld256_or_zero_cg_u32<Type>(
in_vec2, &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 8],
valid);
} else {
ld128_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid);
ld128_or_zero_cg_u32<Type>(
in_vec2, &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 4],
valid);
}
// Compute silu and mul
PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2);
PackedVec out_silu_mul = compute_silu_mul<Type>(in_vec, in_vec2);
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx, colIdx, numKTiles, SFout);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal,
sf_out);
auto out_val =
cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
out_silu_mul, SFScaleVal, sf_out);
if (valid) {
if constexpr (CVT_FP4_PACK16) {
int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2;
uint64_t packed64 =
(uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
} else {
out[inOffset] = out_val;
}
}
}
}
}
@@ -103,17 +136,23 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x));
int grid_x = std::min(
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
m, n, input_ptr, input_sf_ptr,
m, n, sf_n_unpadded, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});

View File

@@ -140,8 +140,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
out_pos = cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
quant_input, SFScaleVal, sf_out);
}
}
@@ -246,8 +246,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
out_pos = cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
quant_input, SFScaleVal, sf_out);
}
}

View File

@@ -21,7 +21,8 @@
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf);
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
@@ -51,10 +52,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
#endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf) {
torch::Tensor& output_sf, torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf);
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
}

View File

@@ -27,29 +27,23 @@
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
// Define before including nvfp4_utils.cuh so the header
// can use this macro during compilation.
#define NVFP4_ENABLE_ELTS16 1
#include "nvfp4_utils.cuh"
namespace vllm {
template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
static_assert(std::is_integral_v<Int>,
"round_up argument must be integral type");
return ((x + y - 1) / y) * y;
}
// Compute effective rows for grid configuration with swizzled SF layouts.
inline int computeEffectiveRows(int m) {
constexpr int ROW_TILE = 128;
return round_up(m, ROW_TILE);
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
float const* SFScale, uint32_t* out, uint32_t* SFout) {
using PackedVec = PackedVec<Type>;
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, int32_t num_padded_cols,
Type const* __restrict__ in,
float const* __restrict__ SFScale,
uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) {
using PackedVec = vllm::PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
@@ -59,33 +53,31 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
int32_t const numKTiles = (numCols + 63) / 64;
int sf_m = round_up<int>(numRows, 128);
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
int num_padded_cols = sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE;
int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;
int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];
float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0];
// Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it.
for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x;
colIdx < num_padded_cols / CVT_FP4_ELTS_PER_THREAD;
colIdx += blockDim.x) {
int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;
if (colIdx < num_padded_cols) {
PackedVec in_vec;
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
// If we are outside valid rows OR outside valid columns -> Use Zeros
if (rowIdx >= numRows || elem_idx >= numCols) {
memset(&in_vec, 0, sizeof(PackedVec));
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
if constexpr (CVT_FP4_PACK16) {
ld256_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid);
} else {
// Valid Region: Load actual data
in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
ld128_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid);
}
auto sf_out =
@@ -94,13 +86,85 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
rowIdx, colIdx, numKTiles, SFout);
auto out_val =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
in_vec, global_scale, sf_out);
// We do NOT write output for padding because the 'out' tensor is not
// padded.
if (rowIdx < numRows && elem_idx < numCols) {
// Same as inOffset because 8 elements are packed into one uint32_t.
out[inOffset] = out_val;
if (valid) {
if constexpr (CVT_FP4_PACK16) {
int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2;
uint64_t packed64 =
(uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
} else {
out[inOffset] = out_val;
}
}
}
}
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols,
int32_t sf_n_unpadded, Type const* __restrict__ in,
float const* __restrict__ SFScale,
uint32_t* __restrict__ out,
uint32_t* __restrict__ SFout) {
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;
int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0];
// Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
if (colIdx < sf_n_unpadded) {
PackedVec in_vec;
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
// If we are outside valid rows OR outside valid columns -> Use Zeros
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
if constexpr (CVT_FP4_PACK16) {
ld256_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid);
} else {
ld128_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid);
}
auto sf_out =
sf_out_rowmajor_u8<uint32_t>(rowIdx, colIdx, sf_n_unpadded, SFout);
auto out_val =
cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
in_vec, global_scale, sf_out);
// We do NOT write output for padding because the 'out' tensor is not
// padded.
if (valid) {
if constexpr (CVT_FP4_PACK16) {
int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2;
uint64_t packed64 =
(uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
} else {
out[inOffset] = out_val;
}
}
}
}
@@ -111,7 +175,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf) {
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int32_t m = input.size(0);
int32_t n = input.size(1);
@@ -129,19 +194,48 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
// Grid, Block size. Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
int effectiveRows = vllm::computeEffectiveRows(m);
dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, input_ptr, input_sf_ptr, reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}
if (is_sf_swizzled_layout) {
int sf_n_int = int(vllm::round_up(sf_n_unpadded, 4) / 4);
int32_t num_padded_cols =
sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
int grid_y = vllm::div_round_up(num_padded_cols, static_cast<int>(block.x));
int grid_x =
std::min(vllm::computeEffectiveRows(m),
std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
} else {
int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x));
int grid_x = std::min(
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, input_ptr,
input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}
}

View File

@@ -19,9 +19,17 @@
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#define ELTS_PER_THREAD 8
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
#define ELTS_PER_THREAD 16
constexpr int CVT_FP4_ELTS_PER_THREAD = 16;
constexpr bool CVT_FP4_PACK16 = true;
#else
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr bool CVT_FP4_PACK16 = false;
#endif
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
namespace vllm {
@@ -68,19 +76,46 @@ struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
// Define a 32 bytes packed data type.
template <class Type>
struct alignas(32) PackedVec {
typename TypeConverter<Type>::Type elts[8];
};
#else
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
struct alignas(16) PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
#endif
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
static_assert(std::is_integral_v<Int>,
"round_up argument must be integral type");
return ((x + y - 1) / y) * y;
}
template <typename Int>
__host__ __device__ __forceinline__ Int div_round_up(Int x, Int y) {
return (x + y - 1) / y;
}
// Compute effective rows for grid configuration with swizzled SF layouts.
inline int computeEffectiveRows(int m) {
constexpr int ROW_TILE = 128;
return round_up(m, ROW_TILE);
}
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
uint32_t val;
asm volatile(
"{\n"
@@ -101,7 +136,7 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
__device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) {
uint32_t val;
asm volatile(
"{\n"
@@ -114,20 +149,115 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
"}\n"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
}
struct u32x2 {
uint32_t lo, hi;
};
using fp4_packed_t = std::conditional_t<CVT_FP4_PACK16, u32x2, uint32_t>;
__device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) {
u32x2 out;
asm volatile(
"{\n"
".reg .b8 b0;\n"
".reg .b8 b1;\n"
".reg .b8 b2;\n"
".reg .b8 b3;\n"
".reg .b8 b4;\n"
".reg .b8 b5;\n"
".reg .b8 b6;\n"
".reg .b8 b7;\n"
"cvt.rn.satfinite.e2m1x2.f32 b0, %3, %2;\n"
"cvt.rn.satfinite.e2m1x2.f32 b1, %5, %4;\n"
"cvt.rn.satfinite.e2m1x2.f32 b2, %7, %6;\n"
"cvt.rn.satfinite.e2m1x2.f32 b3, %9, %8;\n"
"cvt.rn.satfinite.e2m1x2.f32 b4, %11, %10;\n"
"cvt.rn.satfinite.e2m1x2.f32 b5, %13, %12;\n"
"cvt.rn.satfinite.e2m1x2.f32 b6, %15, %14;\n"
"cvt.rn.satfinite.e2m1x2.f32 b7, %17, %16;\n"
"mov.b32 %0, {b0, b1, b2, b3};\n"
"mov.b32 %1, {b4, b5, b6, b7};\n"
"}\n"
: "=r"(out.lo), "=r"(out.hi)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y),
"f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y),
"f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y));
return out;
}
__device__ __forceinline__ uint32_t pack_fp4(float2 (&v)[4]) {
return fp32_vec8_to_e2m1(v);
}
__device__ __forceinline__ u32x2 pack_fp4(float2 (&v)[8]) {
return fp32_vec16_to_e2m1(v);
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
__device__ __forceinline__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(b) : "f"(a));
return b;
}
template <class Type>
__device__ __forceinline__ void ld128_or_zero_cg_u32(PackedVec<Type>& out,
const void* ptr,
bool pred) {
uint32_t r0, r1, r2, r3;
asm volatile(
"{\n"
" .reg .pred pr;\n"
" setp.ne.u32 pr, %4, 0;\n"
" mov.u32 %0, 0;\n"
" mov.u32 %1, 0;\n"
" mov.u32 %2, 0;\n"
" mov.u32 %3, 0;\n"
" @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n"
"}\n"
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
: "r"((int)pred), "l"(ptr));
*reinterpret_cast<uint4*>(&out) = uint4{r0, r1, r2, r3};
}
template <class Type>
__device__ __forceinline__ void ld256_or_zero_cg_u32(PackedVec<Type>& out,
const void* ptr,
bool pred) {
uint32_t r0, r1, r2, r3, r4, r5, r6, r7;
asm volatile(
"{\n"
" .reg .pred pr;\n"
" setp.ne.u32 pr, %8, 0;\n"
" mov.u32 %0, 0;\n"
" mov.u32 %1, 0;\n"
" mov.u32 %2, 0;\n"
" mov.u32 %3, 0;\n"
" mov.u32 %4, 0;\n"
" mov.u32 %5, 0;\n"
" mov.u32 %6, 0;\n"
" mov.u32 %7, 0;\n"
" @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n"
"}\n"
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4), "=r"(r5), "=r"(r6),
"=r"(r7)
: "r"((int)pred), "l"(ptr));
reinterpret_cast<uint4*>(&out)[0] = uint4{r0, r1, r2, r3};
reinterpret_cast<uint4*>(&out)[1] = uint4{r4, r5, r6, r7};
}
// Compute SF output offset for swizzled tensor core layout.
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
// Caller must precompute: numKTiles = (numCols + 63) / 64
@@ -166,21 +296,41 @@ __device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
template <class SFType>
__device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack,
int packs_per_row_sf,
SFType* SFout) {
constexpr int PACK = CVT_FP4_ELTS_PER_THREAD;
constexpr int THREADS_PER_SF =
CVT_FP4_SF_VEC_SIZE / PACK; // 1 if PACK=16, 2 else PACK=8
if (threadIdx.x % THREADS_PER_SF != 0) return nullptr;
int sf_col =
pack / THREADS_PER_SF; // PACK=16 => sf_col=pack; PACK=8 => sf_col=pack/2
int64_t off = (int64_t)row * packs_per_row_sf + sf_col;
return (uint8_t*)SFout + off;
}
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
uint8_t* SFout) {
template <class Type, int CVT_FP4_NUM_THREADS_PER_SF, bool UE8M0_SF = false>
__device__ __forceinline__ fp4_packed_t
cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
if constexpr (CVT_FP4_NUM_THREADS_PER_SF == 2) {
localMax = __hmax2(__shfl_xor_sync(0xffffffffu, localMax, 1), localMax);
}
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
@@ -205,18 +355,17 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
// Convert back to fp32.
SFValue = float(tmp);
}
// Write the SF to global memory (STG.8).
if (SFout) *SFout = fp8SFVal;
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
SFValue != 0.0f ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
@@ -233,10 +382,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
return pack_fp4(fp2Vals);
}
// silu in float32

View File

@@ -546,7 +546,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor! output, Tensor input,"
" Tensor! output_scale, Tensor input_scale) -> ()");
" Tensor! output_scale, Tensor input_scale, bool "
"is_sf_swizzled_layout) -> ()");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
// Compute NVFP4 experts quantization.

View File

@@ -107,10 +107,14 @@ def test_flashinfer_nvfp4_gemm(
# from checkpoints are in linear scales.
# So instead of needing to swizzle for cutlass as in modelopt.py,
# we need to unswizzle for trtllm here.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale, backend)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
a_dtype, a_global_scale, is_sf_swizzled_layout=True, backend=backend
)
is_sf_128x4_layout = not (backend == "trtllm" and m <= 32)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(
b_dtype, b_global_scale, is_sf_swizzled_layout=True
)
# get_ref_results unswizzles the scales internally.
expected_out = get_ref_results(

View File

@@ -27,6 +27,12 @@ PAD_SHAPES = [
(150, 128),
(150, 48),
(90, 80),
(128, 512),
(128, 1024),
(128, 2048),
(64, 7168),
(64, 7152),
(32, 14336),
]
SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]
@@ -173,3 +179,25 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
out_ans = cast_from_fp4(out, m, n)
torch.testing.assert_close(out_ans, out_ref)
torch.testing.assert_close(scale_ans, scale_ref)
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
@torch.inference_mode()
def test_quantize_to_fp4_padded_no_sf_swizzled(pad_shape: tuple[int, int]) -> None:
dtype = torch.float16
set_random_seed(42)
torch.set_default_device("cuda:0")
m, n = pad_shape
x = torch.randn((m, n), dtype=dtype)
tensor_amax = torch.abs(x).max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
out, out_scale = ops.scaled_fp4_quant(x, global_scale, is_sf_swizzled_layout=False)
scale_ans = out_scale.to(torch.float32)
out_ans = cast_from_fp4(out, m, n)
torch.testing.assert_close(out_ans, out_ref)
torch.testing.assert_close(scale_ans, scale_ref)

View File

@@ -1534,6 +1534,7 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
def scaled_fp4_quant(
input: torch.Tensor,
input_global_scale: torch.Tensor,
is_sf_swizzled_layout: bool = True,
backend: str = "none",
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -1577,22 +1578,26 @@ def scaled_fp4_quant(
else:
# Two fp4 values will be packed into an uint8.
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
if is_sf_swizzled_layout:
# We use the rounded values to store the swizzled values. Due to the
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(m, 128)
scale_n = n // block_size
rounded_n = round_up(scale_n, 4)
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)
else:
output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8)
# We use the rounded values to store the swizzled values. Due to the
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(m, 128)
scale_n = n // block_size
rounded_n = round_up(scale_n, 4)
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
torch.ops._C.scaled_fp4_quant(
output, input, output_scale, input_global_scale, is_sf_swizzled_layout
)
torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale)
output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale

View File

@@ -152,6 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
input=result_silu_mul,
output_scale=output_scale,
input_scale=scale,
is_sf_swizzled_layout=True,
)
return at[1], at[2]

View File

@@ -946,6 +946,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
input=rms,
output_scale=output_scale,
input_scale=input_global_scale,
is_sf_swizzled_layout=True,
)
# quant_out, allreduce_output, output_scale
@@ -1043,6 +1044,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
input=rms,
output_scale=output_scale,
input_scale=input_global_scale,
is_sf_swizzled_layout=True,
)
# quant_out, allreduce_output, output_scale

View File

@@ -248,6 +248,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
input=attn_out_view,
output_scale=output_scale,
input_scale=input_scale,
is_sf_swizzled_layout=True,
)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view

View File

@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize,
)
from vllm.triton_utils import tl, triton
from vllm.utils.flashinfer import flashinfer_fp4_quantize
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -117,9 +116,7 @@ def _nvfp4_quantize(
A_scale: torch.Tensor | None,
is_sf_swizzled_layout: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return flashinfer_fp4_quantize(
A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
)
return ops.scaled_fp4_quant(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout)
def _fp8_quantize(

View File

@@ -191,7 +191,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(
x, layer.input_global_scale, self.backend
x,
layer.input_global_scale,
is_sf_swizzled_layout=True,
backend=self.backend,
)
mm_args = (

View File

@@ -1307,7 +1307,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv, self.backend)
x_fp4, x_blockscale = scaled_fp4_quant(
x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend
)
# validate dtypes of quantized input, input block scale,
# weight and weight_blockscale

View File

@@ -8,6 +8,7 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -341,10 +342,8 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states is the already quantized
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
layer.a1_gscale,
is_sf_swizzled_layout=False,
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Determine routing method type
@@ -443,10 +442,8 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
layer.a1_gscale,
is_sf_swizzled_layout=False,
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Call TRT-LLM FP4 block-scale MoE kernel