[4/n] Migrate FP4/W4A8 CUTLASS kernels to torch stable ABI (#37503)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
@@ -103,6 +103,102 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||
"bool");
|
||||
|
||||
// CUTLASS nvfp4 block scaled GEMM
|
||||
ops.def(
|
||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor block_scale_a, Tensor block_scale_b,"
|
||||
" Tensor alpha) -> ()");
|
||||
|
||||
// cutlass nvfp4 block scaled group GEMM
|
||||
ops.def(
|
||||
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
|
||||
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor input,"
|
||||
" Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
|
||||
|
||||
// Out variant
|
||||
// TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
|
||||
// This registration is now migrated to stable ABI
|
||||
// at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
|
||||
// yet in torch/headeronly), the tag should be applied from Python
|
||||
// via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
|
||||
// with the .impl remaining in C++.
|
||||
// See pytorch/pytorch#176117.
|
||||
ops.def(
|
||||
"scaled_fp4_quant.out(Tensor input,"
|
||||
" Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
|
||||
"-> ()");
|
||||
|
||||
// Compute NVFP4 experts quantization.
|
||||
ops.def(
|
||||
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
|
||||
// Fused SiLU+Mul+NVFP4 experts quantization.
|
||||
ops.def(
|
||||
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
|
||||
"output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
|
||||
// Fused SiLU+Mul+NVFP4 quantization.
|
||||
ops.def(
|
||||
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
|
||||
"Tensor input, Tensor input_global_scale) -> ()");
|
||||
|
||||
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
|
||||
// of the given capability
|
||||
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
|
||||
|
||||
// CUTLASS w4a8 GEMM
|
||||
ops.def(
|
||||
"cutlass_w4a8_mm("
|
||||
" Tensor A,"
|
||||
" Tensor B,"
|
||||
" Tensor group_scales,"
|
||||
" int group_size,"
|
||||
" Tensor channel_scales,"
|
||||
" Tensor token_scales,"
|
||||
" ScalarType? out_type,"
|
||||
" str? maybe_schedule"
|
||||
") -> Tensor");
|
||||
|
||||
// pack scales
|
||||
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
|
||||
|
||||
// encode and reorder weight matrix
|
||||
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
|
||||
|
||||
// CUTLASS w4a8 grouped GEMM
|
||||
ops.def(
|
||||
"cutlass_w4a8_moe_mm("
|
||||
" Tensor! out_tensors,"
|
||||
" Tensor a_tensors,"
|
||||
" Tensor b_tensors,"
|
||||
" Tensor a_scales,"
|
||||
" Tensor b_scales,"
|
||||
" Tensor b_group_scales,"
|
||||
" int b_group_size,"
|
||||
" Tensor expert_offsets,"
|
||||
" Tensor problem_sizes,"
|
||||
" Tensor a_strides,"
|
||||
" Tensor b_strides,"
|
||||
" Tensor c_strides,"
|
||||
" Tensor group_scale_strides,"
|
||||
" str? maybe_schedule"
|
||||
") -> ()");
|
||||
|
||||
ops.def(
|
||||
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
|
||||
"Tensor)");
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -128,6 +224,18 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
|
||||
ops.impl("get_cutlass_batched_moe_mm_data",
|
||||
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
|
||||
|
||||
// FP4/NVFP4 ops
|
||||
ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
|
||||
ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
|
||||
ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
|
||||
ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
|
||||
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
|
||||
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
|
||||
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
|
||||
|
||||
// W4A8 ops: impl registrations are in the source files
|
||||
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -143,6 +251,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
|
||||
TORCH_BOX(&cutlass_group_gemm_supported));
|
||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
|
||||
ops.impl("cutlass_scaled_mm_supports_fp4",
|
||||
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user