[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)

Signed-off-by: tianrengao <terrygao87@gmail.com>
(cherry picked from commit 3e6a1e1686)
This commit is contained in:
Terry Gao
2026-03-16 15:51:46 -07:00
committed by khluu
parent cdcffafef8
commit eeabf740bb
12 changed files with 213 additions and 44 deletions

View File

@@ -295,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_scale,
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,