* Merge with private repo * Update README * Update README * Update README * Add PyTorch requirements * Fix sync scopes for MQA logits (#256) * Update README
24 lines
635 B
C++
24 lines
635 B
C++
#pragma once
|
|
|
|
#include <cute/arch/mma_sm100_desc.hpp>
|
|
// Reuse some types in the JIT modules
|
|
#include <deep_gemm/common/types.cuh>
|
|
|
|
#include "common.hpp"
|
|
#include "../../utils/exception.hpp"
|
|
|
|
namespace deep_gemm {
|
|
|
|
template <typename size_type_t>
|
|
static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) {
|
|
// `> 0` means interleaving
|
|
// 16B actually means non-swizzling (but interleaving)
|
|
for (const int& mode: {128, 64, 32, 16}) {
|
|
if ((block_size * static_cast<int>(elem_size)) % mode == 0)
|
|
return mode;
|
|
}
|
|
DG_HOST_UNREACHABLE("Unreachable");
|
|
}
|
|
|
|
} // namespace deep_gemm
|