- New CUDA kernel: `sm100_fp8_nvfp4_mega_moe_impl`
  - `kGranK=16` (NVFP4 group size 16, vs. 32 for MXFP4)
  - `kind::mxf4nvf4.block_scale.scale_vec::4X` PTX instruction
  - `float_ue4m3_t` scale-factor type in the instruction descriptor
  - SF layout: `scale_vec::4X` (4 TMEM sub-columns per UMMA atom)
  - UTCCP column stride: `i*8` (vs. `i*4` for MXFP4) for the 4X layout
  - L1 epilogue: UE4M3 activation scales (`float` → `cutlass::float_e4m3_t`)
  - SF loading: `kNumSFUint32 = kHidden/64` (4 UE4M3 values per int32)
  - New PTX wrappers: `SM100_MMA_MXF4NVF4_2x1SM_SS`, `SM100_MMA_MXF4NVF4_SS`
- Python API:
  - `fp8_nvfp4_mega_moe()` with `recipe=(1,1,16)`
  - `transform_nvfp4_weights_for_mega_moe()` for UE4M3 → int32 UTCCP packing
  - `_pack_nvfp4_sf_for_utccp()` helper
- C++ bindings:
  - `mega_nvfp4.hpp` with an NVFP4-specific `SymmBuffer` (SF stride K/16)
  - JIT kernel header with `kGranK=16` TMA descriptors
  - Registered in `python_api.cpp`

NOTE: Both SFA and SFB must use UE4M3, because `scale_format_` is a single bit shared by both. The L1 epilogue therefore converts the activation scales from `float` to UE4M3.
31 lines
896 B
C++
31 lines
896 B
C++
#include <pybind11/pybind11.h>
|
|
#include <torch/python.h>
|
|
|
|
#include "apis/attention.hpp"
|
|
#include "apis/einsum.hpp"
|
|
#include "apis/hyperconnection.hpp"
|
|
#include "apis/gemm.hpp"
|
|
#include "apis/layout.hpp"
|
|
#include "apis/mega.hpp"
|
|
#include "apis/mega_nvfp4.hpp"
|
|
#include "apis/runtime.hpp"
|
|
|
|
// Default the Python module name to `_C` unless the build already supplied
// TORCH_EXTENSION_NAME (e.g. via a -D compiler flag).
#if !defined(TORCH_EXTENSION_NAME)
#define TORCH_EXTENSION_NAME _C
#endif
|
|
|
|
// ReSharper disable once CppParameterMayBeConstPtrOrRef
|
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
m.doc() = "DeepGEMM C++ library";
|
|
|
|
// TODO: make SM80 incompatible issues raise errors
|
|
deep_gemm::attention::register_apis(m);
|
|
deep_gemm::einsum::register_apis(m);
|
|
deep_gemm::hyperconnection::register_apis(m);
|
|
deep_gemm::gemm::register_apis(m);
|
|
deep_gemm::layout::register_apis(m);
|
|
deep_gemm::mega::register_apis(m);
|
|
deep_gemm::mega::nvfp4::register_apis(m);
|
|
deep_gemm::runtime::register_apis(m);
|
|
}
|