feat: NVFP4 mega MoE kernel (scale_vec::4X, UE4M3 block scales)
- New CUDA kernel: sm100_fp8_nvfp4_mega_moe_impl - kGranK=16 (NVFP4 group_size=16, vs MXFP4's 32) - kind::mxf4nvf4.block_scale.scale_vec::4X PTX instruction - float_ue4m3_t scale factor type in instruction descriptor - SF layout: scale_vec::4X (4 TMEM sub-columns per UMMA atom) - UTCCP column stride: i*8 (vs MXFP4's i*4) for 4X layout - L1 epilogue: UE4M3 activation scales (float→cutlass::float_e4m3_t) - SF loading: kNumSFUint32 = kHidden/64 (4 UE4M3 per int32) - New PTX wrappers: SM100_MMA_MXF4NVF4_2x1SM_SS, SM100_MMA_MXF4NVF4_SS - Python API: - fp8_nvfp4_mega_moe() with recipe=(1,1,16) - transform_nvfp4_weights_for_mega_moe() for UE4M3→int32 UTCCP packing - _pack_nvfp4_sf_for_utccp() helper - C++ bindings: - mega_nvfp4.hpp with NVFP4-specific SymmBuffer (SF stride K/16) - JIT kernel header with kGranK=16 TMA descriptors - Registered in python_api.cpp NOTE: Both SFA and SFB must use UE4M3 (scale_format_ is 1-bit, shared). The L1 epilogue converts float→UE4M3 for activation scales.
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include "apis/gemm.hpp"
|
||||
#include "apis/layout.hpp"
|
||||
#include "apis/mega.hpp"
|
||||
#include "apis/mega_nvfp4.hpp"
|
||||
#include "apis/runtime.hpp"
|
||||
|
||||
#ifndef TORCH_EXTENSION_NAME
|
||||
@@ -24,5 +25,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
deep_gemm::gemm::register_apis(m);
|
||||
deep_gemm::layout::register_apis(m);
|
||||
deep_gemm::mega::register_apis(m);
|
||||
deep_gemm::mega::nvfp4::register_apis(m);
|
||||
deep_gemm::runtime::register_apis(m);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user