[Kernel] moe wna16 marlin kernel (#14447)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Jinzhen Lin
2025-04-15 11:05:22 +08:00
committed by GitHub
parent 6b40996ae8
commit d06ba4ed3f
16 changed files with 3477 additions and 329 deletions

View File

@@ -9,7 +9,11 @@
#include <cuda_runtime.h>
#include <iostream>
namespace marlin {
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
// Marlin params
@@ -23,6 +27,7 @@ static constexpr int pipe_stages =
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
@@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {
#endif
} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME

View File

@@ -5,7 +5,11 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace marlin {
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t>
class ScalarType {};
@@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
@@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
#endif
};
} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME
#endif