[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)

Signed-off-by: tianrengao <terrygao87@gmail.com>
Author: Terry Gao
Date: 2026-03-16 15:51:46 -07:00
Committed by: GitHub
Parent: 7961486a9b
Commit: 3e6a1e1686

12 changed files with 213 additions and 44 deletions
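Note on the convention this commit follows: in PyTorch custom-op terms, a functional variant allocates and returns its outputs, while the out variant takes caller-provided output tensors as trailing mutable arguments and returns nothing; the functional form is the one functionalization passes (e.g. under torch.compile) expect. A minimal call-site sketch, assuming only the two signatures visible in the diff below (the wrapper function `example` is illustrative, not part of the commit):

#include <torch/all.h>
#include <tuple>

// Declarations as they appear in the diff below.
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
    torch::Tensor const& input, torch::Tensor const& input_sf,
    bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::Tensor const& input,
                          torch::Tensor const& input_sf,
                          bool is_sf_swizzled_layout, torch::Tensor& output,
                          torch::Tensor& output_sf);

void example(torch::Tensor const& input, torch::Tensor const& input_sf) {
  // Functional variant: allocates (output, output_sf) internally and
  // returns them as a tuple.
  auto [output, output_sf] =
      scaled_fp4_quant_func(input, input_sf, /*is_sf_swizzled_layout=*/true);

  // Out variant: reuses caller-owned buffers; note the mutable outputs
  // moved to the end of the argument list, per the out-variant convention.
  scaled_fp4_quant_out(input, input_sf, /*is_sf_swizzled_layout=*/true,
                       output, output_sf);
}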


@@ -16,6 +16,8 @@
 #include <torch/all.h>
+#include "nvfp4_utils.cuh"
 #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
     (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
 void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
@@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
     torch::Tensor const& output_scale_offset_by_experts);
 #endif
 
-void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
-                      torch::Tensor& output_sf, torch::Tensor const& input_sf,
-                      bool is_sf_swizzled_layout) {
+void scaled_fp4_quant_out(torch::Tensor const& input,
+                          torch::Tensor const& input_sf,
+                          bool is_sf_swizzled_layout, torch::Tensor& output,
+                          torch::Tensor& output_sf) {
 #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
     (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
   return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
@@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
 }
 
+std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
+    torch::Tensor const& input, torch::Tensor const& input_sf,
+    bool is_sf_swizzled_layout) {
+  int64_t n = input.size(-1);
+  int64_t m = input.numel() / n;
+  auto device = input.device();
+
+  // Two fp4 values packed into a uint8
+  auto output = torch::empty(
+      {m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
+
+  torch::Tensor output_sf;
+  if (is_sf_swizzled_layout) {
+    auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
+    output_sf = torch::empty(
+        {sf_m, sf_n},
+        torch::TensorOptions().device(device).dtype(torch::kInt32));
+  } else {
+    output_sf = torch::empty(
+        {m, n / CVT_FP4_SF_VEC_SIZE},
+        torch::TensorOptions().device(device).dtype(torch::kUInt8));
+  }
+
+  scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
+                       output_sf);
+  return {output, output_sf};
+}
+
 void scaled_fp4_experts_quant(
     torch::Tensor& output, torch::Tensor& output_scale,
     torch::Tensor const& input, torch::Tensor const& input_global_scale,
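The binding changes for the new pair of entry points are in the other changed files and not shown here. For orientation, a minimal sketch of how a functional op plus a mutating `.out` overload are typically registered with the PyTorch dispatcher; the library name `vllm_ops` and the schema strings are assumptions for illustration, not the actual vLLM bindings from this commit:

#include <torch/library.h>

// Sketch only: illustrative registration of the two variants.
TORCH_LIBRARY_FRAGMENT(vllm_ops, m) {
  // Functional variant: no mutable arguments, returns fresh tensors.
  m.def(
      "scaled_fp4_quant(Tensor input, Tensor input_sf,"
      " bool is_sf_swizzled_layout) -> (Tensor, Tensor)");
  // Out variant: outputs are keyword-only and annotated as mutable
  // (a!/b!) so the dispatcher knows they are written in place.
  m.def(
      "scaled_fp4_quant.out(Tensor input, Tensor input_sf,"
      " bool is_sf_swizzled_layout, *, Tensor(a!) output,"
      " Tensor(b!) output_sf) -> ()");
}

TORCH_LIBRARY_IMPL(vllm_ops, CUDA, m) {
  m.impl("scaled_fp4_quant", &scaled_fp4_quant_func);
  m.impl("scaled_fp4_quant.out", &scaled_fp4_quant_out);
}

Keeping both schemas lets eager callers avoid allocations by reusing buffers via the out variant, while graph-level passes can rewrite to the pure functional form.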