* [wip] refactor: compile to .cubin Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * refactor: compile to .cubin and add NVRTC option Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * fix: compiler version Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: compat for old drivers Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: save kernel name to file Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: fix win compat Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * fix: windows compat Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> * feat: make API more general Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: drop support for CUDA<12.3 Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * doc: update README Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * Some lints and refactor * Refactor runtime * Several fixes * Refactor environment variables * Code format * Add a TODO * Compatible with CUDA 12.3 * Fix indent * Fix typing * Drop support for Windows * Add a TODO --------- Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
20 lines
738 B
Plaintext
20 lines
738 B
Plaintext
#pragma once
|
|
|
|
#include "utils.cuh"
|
|
|
|
namespace deep_gemm {
|
|
|
|
// TODO: move this function to other files
|
|
__device__ __forceinline__ void
|
|
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
|
int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
|
|
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
|
|
if (num_tma_multicast == 1) {
|
|
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
|
|
} else if (cute::block_rank_in_cluster() == 0) {
|
|
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
|
|
}
|
|
}
|
|
|
|
} // namespace deep_gemm
|