From fcb9df99bd7d0e532bcf2891db2a85bd927605fe Mon Sep 17 00:00:00 2001 From: "Roberto L. Castro" <38211239+LopezCastroRoberto@users.noreply.github.com> Date: Sun, 25 Jan 2026 02:45:27 +0100 Subject: [PATCH] [Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520) Signed-off-by: LopezCastroRoberto --- benchmarks/kernels/bench_nvfp4_quant.py | 77 +++++-- csrc/ops.h | 3 +- .../activation_nvfp4_quant_fusion_kernels.cu | 79 +++++-- csrc/quantization/fp4/nvfp4_experts_quant.cu | 8 +- csrc/quantization/fp4/nvfp4_quant_entry.cu | 9 +- csrc/quantization/fp4/nvfp4_quant_kernels.cu | 186 +++++++++++++---- csrc/quantization/fp4/nvfp4_utils.cuh | 196 +++++++++++++++--- csrc/torch_bindings.cpp | 3 +- .../test_flashinfer_nvfp4_scaled_mm.py | 8 +- .../kernels/quantization/test_nvfp4_quant.py | 28 +++ vllm/_custom_ops.py | 31 +-- vllm/compilation/activation_quant_fusion.py | 1 + vllm/compilation/collective_fusion.py | 2 + vllm/compilation/fusion_attn.py | 1 + vllm/model_executor/layers/fused_moe/utils.py | 5 +- .../schemes/compressed_tensors_w4a4_nvfp4.py | 5 +- .../layers/quantization/modelopt.py | 4 +- .../quantization/utils/flashinfer_fp4_moe.py | 13 +- 18 files changed, 508 insertions(+), 151 deletions(-) diff --git a/benchmarks/kernels/bench_nvfp4_quant.py b/benchmarks/kernels/bench_nvfp4_quant.py index 751737653..c48353820 100644 --- a/benchmarks/kernels/bench_nvfp4_quant.py +++ b/benchmarks/kernels/bench_nvfp4_quant.py @@ -20,8 +20,12 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max PROVIDER_CFGS = { - "vllm": dict(backend="vllm", enabled=True), - "flashinfer": dict(backend="flashinfer", enabled=True), + "vllm": dict(backend="vllm", is_sf_swizzled_layout=False, enabled=True), + "vllm-swizzle": dict(backend="vllm", is_sf_swizzled_layout=True, enabled=True), + "flashinfer": dict(backend="flashinfer", is_sf_swizzled_layout=False, enabled=True), + "flashinfer-swizzle": dict( + backend="flashinfer", is_sf_swizzled_layout=True, enabled=True + ), } _enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] @@ -36,7 +40,7 @@ def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor: @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], - x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], + x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192], x_log=False, line_arg="provider", line_vals=_enabled, @@ -63,19 +67,36 @@ def benchmark(batch_size, provider, N, K): if cfg["backend"] == "vllm": # vLLM's FP4 quantization - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: ops.scaled_fp4_quant(a, a_global_scale), - quantiles=quantiles, - ) + if cfg["is_sf_swizzled_layout"]: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: ops.scaled_fp4_quant( + a, a_global_scale, is_sf_swizzled_layout=True + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: ops.scaled_fp4_quant( + a, a_global_scale, is_sf_swizzled_layout=False + ), + quantiles=quantiles, + ) elif cfg["backend"] == "flashinfer": # FlashInfer's FP4 quantization - # Use is_sf_swizzled_layout=True to match vLLM's output format - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: flashinfer_fp4_quantize( - a, a_global_scale, is_sf_swizzled_layout=True - ), - quantiles=quantiles, - ) + if cfg["is_sf_swizzled_layout"]: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: flashinfer_fp4_quantize( + a, a_global_scale, is_sf_swizzled_layout=True + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: flashinfer_fp4_quantize( + a, a_global_scale, is_sf_swizzled_layout=False + ), + quantiles=quantiles, + ) # Convert ms to us for better readability at small batch sizes to_us = lambda t_ms: t_ms * 1000 @@ -92,7 +113,9 @@ def prepare_shapes(args): return out -def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str): +def _test_accuracy_once( + M: int, K: int, dtype: torch.dtype, device: str, is_sf_swizzled_layout: bool +): """Test accuracy between vLLM and FlashInfer FP4 quantization.""" # Create input tensor a = torch.randn((M, K), device=device, dtype=dtype) @@ -101,11 +124,13 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str): a_global_scale = compute_global_scale(a) # vLLM quantization - vllm_fp4, vllm_scale = ops.scaled_fp4_quant(a, a_global_scale) + vllm_fp4, vllm_scale = ops.scaled_fp4_quant( + a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout + ) # FlashInfer quantization (with swizzled layout to match vLLM's output) flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize( - a, a_global_scale, is_sf_swizzled_layout=True + a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout ) flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn) @@ -114,7 +139,14 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str): vllm_fp4, flashinfer_fp4, ) - print(f"M={M}, K={K}, dtype={dtype}: PASSED") + # Compare scales + torch.testing.assert_close( + vllm_scale, + flashinfer_scale, + ) + print( + f"M={M}, K={K}, dtype={dtype}, is_sf_swizzled_layout={is_sf_swizzled_layout}: PASSED" # noqa: E501 + ) def test_accuracy(): @@ -130,9 +162,10 @@ def test_accuracy(): Ms = [1, 1024] Ks = [4096] - for M in Ms: - for K in Ks: - _test_accuracy_once(M, K, dtype, device) + for is_sf_swizzled_layout in [True, False]: + for M in Ms: + for K in Ks: + _test_accuracy_once(M, K, dtype, device, is_sf_swizzled_layout) print("\nAll accuracy tests passed!") @@ -145,7 +178,7 @@ if __name__ == "__main__": "--models", nargs="+", type=str, - default=["meta-llama/Llama-3.1-8B-Instruct"], + default=["meta-llama/Llama-3.3-70B-Instruct"], choices=list(WEIGHT_SHAPES.keys()), ) parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) diff --git a/csrc/ops.h b/csrc/ops.h index c899535bd..9ee6bda31 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -293,7 +293,8 @@ std::vector cutlass_sparse_compress(torch::Tensor const& a); void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, - torch::Tensor const& input_scale); + torch::Tensor const& input_scale, + bool is_sf_swizzled_layout); void scaled_fp4_experts_quant( torch::Tensor& output, torch::Tensor& output_scale, diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index 2ea229c47..d0264c4d1 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -27,17 +27,24 @@ #include "cuda_utils.h" #include "launch_bounds_utils.h" + +// Define before including nvfp4_utils.cuh so the header +// can use this macro during compilation. +#define NVFP4_ENABLE_ELTS16 1 #include "nvfp4_utils.cuh" namespace vllm { // Use UE4M3 by default. template -__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) - silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, - float const* SFScale, uint32_t* out, - uint32_t* SFout) { - using PackedVec = PackedVec; +__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, + int32_t num_padded_cols, + Type const* __restrict__ in, + float const* __restrict__ SFScale, + uint32_t* __restrict__ out, + uint32_t* __restrict__ SFout) { + using PackedVec = vllm::PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, @@ -49,34 +56,60 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) // Get the global scaling factor, which will be applied to the SF. // Note SFScale is the same as next GEMM's alpha, which is // (448.f / (Alpha_A / 6.f)). - float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; + float const SFScaleVal = (SFScale == nullptr) ? 1.0f : SFScale[0]; + + int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; + int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; // Input tensor row/col loops. for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; - colIdx += blockDim.x) { + if (colIdx < num_padded_cols) { + PackedVec in_vec; + PackedVec in_vec2; int64_t inOffset = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx; int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + numCols / CVT_FP4_ELTS_PER_THREAD + colIdx; - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - PackedVec in_vec2 = reinterpret_cast(in)[inOffset2]; - // Get the output tensor offset. - // Same as inOffset because 8 elements are packed into one uint32_t. - int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; - auto& out_pos = out[outOffset]; + bool valid = (rowIdx < numRows) && (elem_idx < numCols); + if constexpr (CVT_FP4_PACK16) { + ld256_or_zero_cg_u32( + in_vec, &reinterpret_cast(in)[inOffset * 8], + valid); + ld256_or_zero_cg_u32( + in_vec2, &reinterpret_cast(in)[inOffset2 * 8], + valid); + } else { + ld128_or_zero_cg_u32( + in_vec, &reinterpret_cast(in)[inOffset * 4], + valid); + ld128_or_zero_cg_u32( + in_vec2, &reinterpret_cast(in)[inOffset2 * 4], + valid); + } // Compute silu and mul - PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2); + PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2); auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( rowIdx, colIdx, numKTiles, SFout); - out_pos = cvt_warp_fp16_to_fp4(out_silu_mul, SFScaleVal, - sf_out); + auto out_val = + cvt_warp_fp16_to_fp4( + out_silu_mul, SFScaleVal, sf_out); + + if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; + uint64_t packed64 = + (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); + reinterpret_cast(out)[outOffset >> 1] = packed64; + } else { + out[inOffset] = out_val; + } + } } } } @@ -103,17 +136,23 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] auto output_ptr = static_cast(output.data_ptr()); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); - dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024)); + dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); int const numBlocksPerSM = vllm_runtime_blocks_per_sm(static_cast(block.x)); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); + + int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int grid_x = std::min( + int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); + dim3 grid(grid_x, grid_y); VLLM_DISPATCH_HALF_TYPES( input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { using cuda_type = vllm::CUDATypeConverter::Type; auto input_ptr = static_cast(input.data_ptr()); vllm::silu_mul_cvt_fp16_to_fp4<<>>( - m, n, input_ptr, input_sf_ptr, + m, n, sf_n_unpadded, input_ptr, input_sf_ptr, reinterpret_cast(output_ptr), reinterpret_cast(sf_out)); }); diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index aa573c007..32685c201 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -140,8 +140,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) CVT_FP4_NUM_THREADS_PER_SF>( rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); - out_pos = - cvt_warp_fp16_to_fp4(quant_input, SFScaleVal, sf_out); + out_pos = cvt_warp_fp16_to_fp4( + quant_input, SFScaleVal, sf_out); } } @@ -246,8 +246,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) CVT_FP4_NUM_THREADS_PER_SF>( rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); - out_pos = - cvt_warp_fp16_to_fp4(quant_input, SFScaleVal, sf_out); + out_pos = cvt_warp_fp16_to_fp4( + quant_input, SFScaleVal, sf_out); } } diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index 25e0ba848..650b9da8a 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -21,7 +21,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input, torch::Tensor const& output_sf, - torch::Tensor const& input_sf); + torch::Tensor const& input_sf, + bool is_sf_swizzled_layout); #endif #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ @@ -51,10 +52,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( #endif void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, - torch::Tensor& output_sf, torch::Tensor const& input_sf) { + torch::Tensor& output_sf, torch::Tensor const& input_sf, + bool is_sf_swizzled_layout) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) - return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf); + return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf, + is_sf_swizzled_layout); #endif TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); } diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index 8e38deeb6..c27fb69d4 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -27,29 +27,23 @@ #include "cuda_utils.h" #include "launch_bounds_utils.h" + +// Define before including nvfp4_utils.cuh so the header +// can use this macro during compilation. +#define NVFP4_ENABLE_ELTS16 1 #include "nvfp4_utils.cuh" namespace vllm { -template -__host__ __device__ inline Int round_up(Int x, Int y) { - static_assert(std::is_integral_v, - "round_up argument must be integral type"); - return ((x + y - 1) / y) * y; -} - -// Compute effective rows for grid configuration with swizzled SF layouts. -inline int computeEffectiveRows(int m) { - constexpr int ROW_TILE = 128; - return round_up(m, ROW_TILE); -} - // Use UE4M3 by default. template __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) - cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, - float const* SFScale, uint32_t* out, uint32_t* SFout) { - using PackedVec = PackedVec; + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, int32_t num_padded_cols, + Type const* __restrict__ in, + float const* __restrict__ SFScale, + uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { + using PackedVec = vllm::PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, @@ -59,33 +53,31 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) int32_t const numKTiles = (numCols + 63) / 64; int sf_m = round_up(numRows, 128); - int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE; - int sf_n_int = round_up(sf_n_unpadded, 4) / 4; - int num_padded_cols = sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE; + int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; + int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; // Get the global scaling factor, which will be applied to the SF. // Note SFScale is the same as next GEMM's alpha, which is // (448.f / (Alpha_A / 6.f)). - float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0]; + float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0]; // Iterate over all rows and cols including padded ones - // ensures we visit every single scale factor address to initialize it. for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) { - for (int colIdx = threadIdx.x; - colIdx < num_padded_cols / CVT_FP4_ELTS_PER_THREAD; - colIdx += blockDim.x) { - int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; - + if (colIdx < num_padded_cols) { PackedVec in_vec; int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; // If we are outside valid rows OR outside valid columns -> Use Zeros - if (rowIdx >= numRows || elem_idx >= numCols) { - memset(&in_vec, 0, sizeof(PackedVec)); - + bool valid = (rowIdx < numRows) && (elem_idx < numCols); + if constexpr (CVT_FP4_PACK16) { + ld256_or_zero_cg_u32( + in_vec, &reinterpret_cast(in)[inOffset * 8], + valid); } else { - // Valid Region: Load actual data - in_vec = reinterpret_cast(in)[inOffset]; + ld128_or_zero_cg_u32( + in_vec, &reinterpret_cast(in)[inOffset * 4], + valid); } auto sf_out = @@ -94,13 +86,85 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) rowIdx, colIdx, numKTiles, SFout); auto out_val = - cvt_warp_fp16_to_fp4(in_vec, global_scale, sf_out); + cvt_warp_fp16_to_fp4( + in_vec, global_scale, sf_out); // We do NOT write output for padding because the 'out' tensor is not // padded. - if (rowIdx < numRows && elem_idx < numCols) { - // Same as inOffset because 8 elements are packed into one uint32_t. - out[inOffset] = out_val; + if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; + uint64_t packed64 = + (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); + reinterpret_cast(out)[outOffset >> 1] = packed64; + } else { + out[inOffset] = out_val; + } + } + } + } +} + +// Use UE4M3 by default. +template +__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols, + int32_t sf_n_unpadded, Type const* __restrict__ in, + float const* __restrict__ SFScale, + uint32_t* __restrict__ out, + uint32_t* __restrict__ SFout) { + using PackedVec = PackedVec; + + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; + int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0]; + + // Iterate over all rows and cols including padded ones - + // ensures we visit every single scale factor address to initialize it. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + if (colIdx < sf_n_unpadded) { + PackedVec in_vec; + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + + // If we are outside valid rows OR outside valid columns -> Use Zeros + bool valid = (rowIdx < numRows) && (elem_idx < numCols); + if constexpr (CVT_FP4_PACK16) { + ld256_or_zero_cg_u32( + in_vec, &reinterpret_cast(in)[inOffset * 8], + valid); + } else { + ld128_or_zero_cg_u32( + in_vec, &reinterpret_cast(in)[inOffset * 4], + valid); + } + + auto sf_out = + sf_out_rowmajor_u8(rowIdx, colIdx, sf_n_unpadded, SFout); + + auto out_val = + cvt_warp_fp16_to_fp4( + in_vec, global_scale, sf_out); + + // We do NOT write output for padding because the 'out' tensor is not + // padded. + if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; + uint64_t packed64 = + (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); + reinterpret_cast(out)[outOffset >> 1] = packed64; + } else { + out[inOffset] = out_val; + } } } } @@ -111,7 +175,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input, torch::Tensor const& output_sf, - torch::Tensor const& input_sf) { + torch::Tensor const& input_sf, + bool is_sf_swizzled_layout) { int32_t m = input.size(0); int32_t n = input.size(1); @@ -129,19 +194,48 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); + // Grid, Block size. Each thread converts 8 values. dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); int const numBlocksPerSM = vllm_runtime_blocks_per_sm(static_cast(block.x)); - int effectiveRows = vllm::computeEffectiveRows(m); - dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); - VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - auto input_ptr = static_cast(input.data_ptr()); - // NOTE: We don't support e8m0 scales at this moment. - vllm::cvt_fp16_to_fp4<<>>( - m, n, input_ptr, input_sf_ptr, reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); -} \ No newline at end of file + if (is_sf_swizzled_layout) { + int sf_n_int = int(vllm::round_up(sf_n_unpadded, 4) / 4); + int32_t num_padded_cols = + sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + + int grid_y = vllm::div_round_up(num_padded_cols, static_cast(block.x)); + int grid_x = + std::min(vllm::computeEffectiveRows(m), + std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); + dim3 grid(grid_x, grid_y); + + VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + // NOTE: We don't support e8m0 scales at this moment. + vllm::cvt_fp16_to_fp4<<>>( + m, n, num_padded_cols, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); + } else { + int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int grid_x = std::min( + m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); + dim3 grid(grid_x, grid_y); + + VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + // NOTE: We don't support e8m0 scales at this moment. + vllm::cvt_fp16_to_fp4_sf_major + <<>>(m, n, sf_n_unpadded, input_ptr, + input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); + } +} diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 7082ad684..3e7adb9e2 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -19,9 +19,17 @@ #include #include -#define ELTS_PER_THREAD 8 - +#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ + defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) + #define ELTS_PER_THREAD 16 +constexpr int CVT_FP4_ELTS_PER_THREAD = 16; +constexpr bool CVT_FP4_PACK16 = true; +#else + #define ELTS_PER_THREAD 8 constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr bool CVT_FP4_PACK16 = false; +#endif + constexpr int CVT_FP4_SF_VEC_SIZE = 16; namespace vllm { @@ -68,19 +76,46 @@ struct TypeConverter<__nv_bfloat16> { using Type = __nv_bfloat162; }; +#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ + defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) +// Define a 32 bytes packed data type. +template +struct alignas(32) PackedVec { + typename TypeConverter::Type elts[8]; +}; +#else // Define a 16 bytes packed data type. template -struct PackedVec { +struct alignas(16) PackedVec { typename TypeConverter::Type elts[4]; }; +#endif template <> struct PackedVec<__nv_fp8_e4m3> { __nv_fp8x2_e4m3 elts[8]; }; +template +__host__ __device__ inline Int round_up(Int x, Int y) { + static_assert(std::is_integral_v, + "round_up argument must be integral type"); + return ((x + y - 1) / y) * y; +} + +template +__host__ __device__ __forceinline__ Int div_round_up(Int x, Int y) { + return (x + y - 1) / y; +} + +// Compute effective rows for grid configuration with swizzled SF layouts. +inline int computeEffectiveRows(int m) { + constexpr int ROW_TILE = 128; + return round_up(m, ROW_TILE); +} + // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { +inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { uint32_t val; asm volatile( "{\n" @@ -101,7 +136,7 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { } // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { +__device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) { uint32_t val; asm volatile( "{\n" @@ -114,20 +149,115 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" + "}\n" : "=r"(val) : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); return val; } +struct u32x2 { + uint32_t lo, hi; +}; + +using fp4_packed_t = std::conditional_t; + +__device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) { + u32x2 out; + asm volatile( + "{\n" + ".reg .b8 b0;\n" + ".reg .b8 b1;\n" + ".reg .b8 b2;\n" + ".reg .b8 b3;\n" + ".reg .b8 b4;\n" + ".reg .b8 b5;\n" + ".reg .b8 b6;\n" + ".reg .b8 b7;\n" + "cvt.rn.satfinite.e2m1x2.f32 b0, %3, %2;\n" + "cvt.rn.satfinite.e2m1x2.f32 b1, %5, %4;\n" + "cvt.rn.satfinite.e2m1x2.f32 b2, %7, %6;\n" + "cvt.rn.satfinite.e2m1x2.f32 b3, %9, %8;\n" + "cvt.rn.satfinite.e2m1x2.f32 b4, %11, %10;\n" + "cvt.rn.satfinite.e2m1x2.f32 b5, %13, %12;\n" + "cvt.rn.satfinite.e2m1x2.f32 b6, %15, %14;\n" + "cvt.rn.satfinite.e2m1x2.f32 b7, %17, %16;\n" + "mov.b32 %0, {b0, b1, b2, b3};\n" + "mov.b32 %1, {b4, b5, b6, b7};\n" + "}\n" + : "=r"(out.lo), "=r"(out.hi) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y), + "f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y), + "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y)); + return out; +} + +__device__ __forceinline__ uint32_t pack_fp4(float2 (&v)[4]) { + return fp32_vec8_to_e2m1(v); +} + +__device__ __forceinline__ u32x2 pack_fp4(float2 (&v)[8]) { + return fp32_vec16_to_e2m1(v); +} + // Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { +__device__ __forceinline__ float reciprocal_approximate_ftz(float a) { float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(b) : "f"(a)); return b; } +template +__device__ __forceinline__ void ld128_or_zero_cg_u32(PackedVec& out, + const void* ptr, + bool pred) { + uint32_t r0, r1, r2, r3; + + asm volatile( + "{\n" + " .reg .pred pr;\n" + " setp.ne.u32 pr, %4, 0;\n" + " mov.u32 %0, 0;\n" + " mov.u32 %1, 0;\n" + " mov.u32 %2, 0;\n" + " mov.u32 %3, 0;\n" + " @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n" + "}\n" + : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) + : "r"((int)pred), "l"(ptr)); + + *reinterpret_cast(&out) = uint4{r0, r1, r2, r3}; +} + +template +__device__ __forceinline__ void ld256_or_zero_cg_u32(PackedVec& out, + const void* ptr, + bool pred) { + uint32_t r0, r1, r2, r3, r4, r5, r6, r7; + + asm volatile( + "{\n" + " .reg .pred pr;\n" + " setp.ne.u32 pr, %8, 0;\n" + " mov.u32 %0, 0;\n" + " mov.u32 %1, 0;\n" + " mov.u32 %2, 0;\n" + " mov.u32 %3, 0;\n" + " mov.u32 %4, 0;\n" + " mov.u32 %5, 0;\n" + " mov.u32 %6, 0;\n" + " mov.u32 %7, 0;\n" + " @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n" + "}\n" + : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4), "=r"(r5), "=r"(r6), + "=r"(r7) + : "r"((int)pred), "l"(ptr)); + + reinterpret_cast(&out)[0] = uint4{r0, r1, r2, r3}; + reinterpret_cast(&out)[1] = uint4{r4, r5, r6, r7}; +} + // Compute SF output offset for swizzled tensor core layout. // SF layout: [numMTiles, numKTiles, 32, 4, 4] // Caller must precompute: numKTiles = (numCols + 63) / 64 @@ -166,21 +296,41 @@ __device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset( return reinterpret_cast(SFout) + SFOffset; } +template +__device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack, + int packs_per_row_sf, + SFType* SFout) { + constexpr int PACK = CVT_FP4_ELTS_PER_THREAD; + constexpr int THREADS_PER_SF = + CVT_FP4_SF_VEC_SIZE / PACK; // 1 if PACK=16, 2 else PACK=8 + + if (threadIdx.x % THREADS_PER_SF != 0) return nullptr; + + int sf_col = + pack / THREADS_PER_SF; // PACK=16 => sf_col=pack; PACK=8 => sf_col=pack/2 + int64_t off = (int64_t)row * packs_per_row_sf + sf_col; + + return (uint8_t*)SFout + off; +} + // Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, - uint8_t* SFout) { +template +__device__ __forceinline__ fp4_packed_t +cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { // Get absolute maximum values among the local 8 values. auto localMax = __habs2(vec.elts[0]); -// Local maximum value. + // Local maximum value. #pragma unroll for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { localMax = __hmax2(localMax, __habs2(vec.elts[i])); } // Get the absolute maximum among all 16 values (two threads). - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + + if constexpr (CVT_FP4_NUM_THREADS_PER_SF == 2) { + localMax = __hmax2(__shfl_xor_sync(0xffffffffu, localMax, 1), localMax); + } // Get the final absolute maximum values. float vecMax = float(__hmax(localMax.x, localMax.y)); @@ -205,18 +355,17 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, // Convert back to fp32. SFValue = float(tmp); } + + // Write the SF to global memory (STG.8). + if (SFout) *SFout = fp8SFVal; + // Get the output scale. // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * // reciprocal(SFScaleVal)) float outputScale = - SFValue != 0 ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } + SFValue != 0.0f ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; // Convert the input to float. float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; @@ -233,10 +382,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, } // Convert to e2m1 values. - uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; + return pack_fp4(fp2Vals); } // silu in float32 diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 68257bdda..cdaf873a1 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -546,7 +546,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute NVFP4 block quantized tensor. ops.def( "scaled_fp4_quant(Tensor! output, Tensor input," - " Tensor! output_scale, Tensor input_scale) -> ()"); + " Tensor! output_scale, Tensor input_scale, bool " + "is_sf_swizzled_layout) -> ()"); ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); // Compute NVFP4 experts quantization. diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py index d615bb7dc..04e28dd20 100644 --- a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -107,10 +107,14 @@ def test_flashinfer_nvfp4_gemm( # from checkpoints are in linear scales. # So instead of needing to swizzle for cutlass as in modelopt.py, # we need to unswizzle for trtllm here. - a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale, backend) + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant( + a_dtype, a_global_scale, is_sf_swizzled_layout=True, backend=backend + ) is_sf_128x4_layout = not (backend == "trtllm" and m <= 32) - b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) + b_fp4, b_scale_interleaved = ops.scaled_fp4_quant( + b_dtype, b_global_scale, is_sf_swizzled_layout=True + ) # get_ref_results unswizzles the scales internally. expected_out = get_ref_results( diff --git a/tests/kernels/quantization/test_nvfp4_quant.py b/tests/kernels/quantization/test_nvfp4_quant.py index d17c69663..1d2f9d413 100644 --- a/tests/kernels/quantization/test_nvfp4_quant.py +++ b/tests/kernels/quantization/test_nvfp4_quant.py @@ -27,6 +27,12 @@ PAD_SHAPES = [ (150, 128), (150, 48), (90, 80), + (128, 512), + (128, 1024), + (128, 2048), + (64, 7168), + (64, 7152), + (32, 14336), ] SEEDS = [42] CUDA_DEVICES = ["cuda:0"] @@ -173,3 +179,25 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: out_ans = cast_from_fp4(out, m, n) torch.testing.assert_close(out_ans, out_ref) torch.testing.assert_close(scale_ans, scale_ref) + + +@pytest.mark.parametrize("pad_shape", PAD_SHAPES) +@torch.inference_mode() +def test_quantize_to_fp4_padded_no_sf_swizzled(pad_shape: tuple[int, int]) -> None: + dtype = torch.float16 + set_random_seed(42) + torch.set_default_device("cuda:0") + + m, n = pad_shape + + x = torch.randn((m, n), dtype=dtype) + + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + out_ref, scale_ref = ref_nvfp4_quant(x, global_scale) + + out, out_scale = ops.scaled_fp4_quant(x, global_scale, is_sf_swizzled_layout=False) + scale_ans = out_scale.to(torch.float32) + out_ans = cast_from_fp4(out, m, n) + torch.testing.assert_close(out_ans, out_ref) + torch.testing.assert_close(scale_ans, scale_ref) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ff63aef38..20f399d7f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1534,6 +1534,7 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: def scaled_fp4_quant( input: torch.Tensor, input_global_scale: torch.Tensor, + is_sf_swizzled_layout: bool = True, backend: str = "none", ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -1577,22 +1578,26 @@ def scaled_fp4_quant( else: # Two fp4 values will be packed into an uint8. output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + if is_sf_swizzled_layout: + # We use the rounded values to store the swizzled values. Due to the + # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. + # So, we first pad the scales to multiples of 128 and 4. Then, the scales + # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: + # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(m, 128) + scale_n = n // block_size + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty( + (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + ) + else: + output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8) - # We use the rounded values to store the swizzled values. Due to the - # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. - # So, we first pad the scales to multiples of 128 and 4. Then, the scales - # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: - # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x - round_up = lambda x, y: (x + y - 1) // y * y - rounded_m = round_up(m, 128) - scale_n = n // block_size - rounded_n = round_up(scale_n, 4) - output_scale = torch.empty( - (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + torch.ops._C.scaled_fp4_quant( + output, input, output_scale, input_global_scale, is_sf_swizzled_layout ) - torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale) - output_scale = output_scale.view(torch.float8_e4m3fn) return output, output_scale diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index 8530c0dad..1eb23bf03 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -152,6 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): input=result_silu_mul, output_scale=output_scale, input_scale=scale, + is_sf_swizzled_layout=True, ) return at[1], at[2] diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index b5d162209..d7514a170 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -946,6 +946,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): input=rms, output_scale=output_scale, input_scale=input_global_scale, + is_sf_swizzled_layout=True, ) # quant_out, allreduce_output, output_scale @@ -1043,6 +1044,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): input=rms, output_scale=output_scale, input_scale=input_global_scale, + is_sf_swizzled_layout=True, ) # quant_out, allreduce_output, output_scale diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 69dc2e3a6..618892ad3 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -248,6 +248,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): input=attn_out_view, output_scale=output_scale, input_scale=input_scale, + is_sf_swizzled_layout=True, ) output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) return at2[1], output_scale_view diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index cd89f7c85..a4b20505e 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( mxfp8_e4m3_quantize, ) from vllm.triton_utils import tl, triton -from vllm.utils.flashinfer import flashinfer_fp4_quantize from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -117,9 +116,7 @@ def _nvfp4_quantize( A_scale: torch.Tensor | None, is_sf_swizzled_layout: bool, ) -> tuple[torch.Tensor, torch.Tensor]: - return flashinfer_fp4_quantize( - A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout - ) + return ops.scaled_fp4_quant(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout) def _fp8_quantize( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index d7f34e4f5..762498378 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -191,7 +191,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant( - x, layer.input_global_scale, self.backend + x, + layer.input_global_scale, + is_sf_swizzled_layout=True, + backend=self.backend, ) mm_args = ( diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e65d23e36..f26aa045b 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1307,7 +1307,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): output_shape = [x.shape[0], layer.weight.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv, self.backend) + x_fp4, x_blockscale = scaled_fp4_quant( + x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend + ) # validate dtypes of quantized input, input block scale, # weight and weight_blockscale diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 5130f6e40..52d0a5c47 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -8,6 +8,7 @@ import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, @@ -341,10 +342,8 @@ def flashinfer_trtllm_fp4_moe( hidden_states_fp4, hidden_states_scale_linear_fp4 = x else: # hidden_states is the already quantized - (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - layer.a1_gscale, - is_sf_swizzled_layout=False, + (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant( + x, layer.a1_gscale, is_sf_swizzled_layout=False ) # Determine routing method type @@ -443,10 +442,8 @@ def flashinfer_trtllm_fp4_routed_moe( hidden_states_fp4, hidden_states_scale_linear_fp4 = x else: # Quantize input to FP4 - (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - layer.a1_gscale, - is_sf_swizzled_layout=False, + (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant( + x, layer.a1_gscale, is_sf_swizzled_layout=False ) # Call TRT-LLM FP4 block-scale MoE kernel