[kernel] Support W4A8 on Hopper (#23198)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere
2025-08-24 02:18:04 -04:00
committed by GitHub
parent a75277285b
commit e76e233540
12 changed files with 1128 additions and 7 deletions

View File

@@ -309,6 +309,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor");
// conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor",
{stride_tag});
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// conditionally compiled so impl registration is in source file
#endif
// Dequantization for GGML.