Add more GPU architectures support (#112)
* Add more GPU architectures support * Update layout.py * Optimize performance, Add SM90 support, Add 1D2D SM100 support * Add fmtlib submodule at commit 553ec11 --------- Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
@@ -1,15 +1,41 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.cpp_extension
|
||||
|
||||
from . import jit
|
||||
from .jit_kernels import (
|
||||
gemm_fp8_fp8_bf16_nt,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||
wgrad_gemm_fp8_fp8_fp32_nt,
|
||||
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt,
|
||||
ceil_div,
|
||||
set_num_sms, get_num_sms,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_m_alignment_for_contiguous_layout
|
||||
# Set some default environment provided at setup
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
|
||||
from .envs import persistent_envs
|
||||
for key, value in persistent_envs.items():
|
||||
if key not in os.environ:
|
||||
os.environ[key] = value
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Import functions from the CPP module
|
||||
import deep_gemm_cpp
|
||||
deep_gemm_cpp.init(
|
||||
os.path.dirname(os.path.abspath(__file__)), # Library root directory path
|
||||
torch.utils.cpp_extension.CUDA_HOME # CUDA home
|
||||
)
|
||||
from .utils import bench, bench_kineto, calc_diff
|
||||
|
||||
# Configs
|
||||
from deep_gemm_cpp import (
|
||||
set_num_sms,
|
||||
get_num_sms
|
||||
)
|
||||
|
||||
# Kernels
|
||||
from deep_gemm_cpp import (
|
||||
fp8_gemm_nt, fp8_gemm_nn,
|
||||
fp8_gemm_tn, fp8_gemm_tt,
|
||||
m_grouped_fp8_gemm_nt_contiguous,
|
||||
m_grouped_fp8_gemm_nn_contiguous,
|
||||
fp8_m_grouped_gemm_nt_masked,
|
||||
k_grouped_fp8_gemm_tn_contiguous
|
||||
)
|
||||
|
||||
# Some utils
|
||||
from . import testing
|
||||
from . import utils
|
||||
from .utils import *
|
||||
|
||||
213
deep_gemm/include/deep_gemm/common/scheduler.cuh
Normal file
213
deep_gemm/include/deep_gemm/common/scheduler.cuh
Normal file
@@ -0,0 +1,213 @@
|
||||
#pragma once
|
||||
|
||||
#include <deep_gemm/common/types.hpp>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class KGroupedIndexType {
|
||||
MN,
|
||||
K,
|
||||
SF_K,
|
||||
};
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
|
||||
template <GemmType kGemmType,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
// TODO: refactor this by other values
|
||||
uint32_t kNum1DBlocksPerGroup = 16>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
|
||||
// Block configs
|
||||
uint32_t num_blocks;
|
||||
uint32_t num_m_blocks;
|
||||
uint32_t num_n_blocks;
|
||||
|
||||
// For SM90 multicast checks
|
||||
uint32_t num_blocks_in_group;
|
||||
bool is_peer_cta_alive = true;
|
||||
|
||||
// For grouped GEMM
|
||||
int* grouped_layout;
|
||||
uint32_t current_group_idx;
|
||||
// Only used for masked layout
|
||||
uint32_t current_m_cumsum;
|
||||
// Only used for k-grouped layout
|
||||
uint32_t current_shape_k, current_num_valid_groups, current_k_cumsum, current_sf_k_cumsum;
|
||||
|
||||
// ReSharper disable once CppPossiblyUninitializedMember
|
||||
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n,
|
||||
int* grouped_layout = nullptr) {
|
||||
num_m_blocks = ceil_div(shape_m, BLOCK_M);
|
||||
num_n_blocks = ceil_div(shape_n, BLOCK_N);
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
num_blocks = num_m_blocks * num_n_blocks;
|
||||
} else if (kGemmType == GemmType::MGroupedContiguous) {
|
||||
num_blocks = num_m_blocks * num_n_blocks;
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::MGroupedMasked) {
|
||||
current_group_idx = current_m_cumsum = 0;
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::KGroupedContiguous) {
|
||||
current_group_idx = current_num_valid_groups = 0;
|
||||
current_k_cumsum = current_sf_k_cumsum = 0;
|
||||
current_shape_k = __ldg(grouped_layout + current_group_idx);
|
||||
this->grouped_layout = grouped_layout;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size");
|
||||
|
||||
// Swizzle for better L2 usages
|
||||
const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks;
|
||||
const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks;
|
||||
const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
|
||||
const auto& group_idx = block_idx / num_blocks_per_group;
|
||||
auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
|
||||
auto in_group_idx = block_idx % num_blocks_per_group;
|
||||
num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
|
||||
|
||||
// Fix unaligned TMA multicast
|
||||
// NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast
|
||||
// while SM100 uses 2-CTA, which can not be dynamically disabled
|
||||
#if __CUDA_ARCH__ < 1000
|
||||
if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) {
|
||||
if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
|
||||
num_blocks_in_group = num_blocks_in_group ^ 1;
|
||||
} else {
|
||||
in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
|
||||
first_block_idx += num_blocks_in_group ^ 1;
|
||||
num_blocks_in_group = 1;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Convert to final M/N block indices
|
||||
if constexpr (kIsMulticastOnA) {
|
||||
m_block_idx = in_group_idx / num_blocks_in_group;
|
||||
n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
|
||||
} else {
|
||||
m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
|
||||
n_block_idx = in_group_idx / num_blocks_in_group;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool kWithGroupOffset, KGroupedIndexType kIndexType = KGroupedIndexType::MN>
|
||||
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
|
||||
const uint32_t& block_idx, const uint32_t& m_block_idx = 0) {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
|
||||
const auto offset = kWithGroupOffset ? std::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
||||
const auto offset = kWithGroupOffset ? current_group_idx : 0;
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
|
||||
auto offset = 0;
|
||||
if constexpr (kWithGroupOffset) {
|
||||
if constexpr (kIndexType == KGroupedIndexType::MN)
|
||||
offset = current_group_idx * shape_dim;
|
||||
else if constexpr (kIndexType == KGroupedIndexType::K)
|
||||
offset = current_k_cumsum;
|
||||
else if constexpr (kIndexType == KGroupedIndexType::SF_K)
|
||||
offset = current_sf_k_cumsum;
|
||||
}
|
||||
return offset + block_idx * block_size;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
|
||||
|
||||
if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
||||
while (true) {
|
||||
// End of the task
|
||||
if (current_group_idx == kNumGroups)
|
||||
return false;
|
||||
|
||||
// Within current group
|
||||
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + current_group_idx)), BLOCK_M);
|
||||
const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks;
|
||||
if (next_block_idx < current_m_block_cumsum * num_n_blocks)
|
||||
break;
|
||||
|
||||
// Move to check the next group
|
||||
current_group_idx ++, current_m_cumsum = current_m_block_cumsum;
|
||||
}
|
||||
|
||||
get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
|
||||
} else if (kGemmType == GemmType::KGroupedContiguous) {
|
||||
while (true) {
|
||||
// End of the task
|
||||
if (current_group_idx == kNumGroups)
|
||||
return false;
|
||||
|
||||
// Within current group
|
||||
if (current_shape_k > 0 and next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
|
||||
break;
|
||||
|
||||
// Move to check the next group
|
||||
if (current_shape_k > 0) {
|
||||
current_k_cumsum += current_shape_k;
|
||||
current_sf_k_cumsum += ceil_div(current_shape_k, 512u);
|
||||
current_num_valid_groups ++;
|
||||
}
|
||||
if ((++ current_group_idx) != kNumGroups)
|
||||
current_shape_k = __ldg(grouped_layout + current_group_idx);
|
||||
}
|
||||
|
||||
get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx);
|
||||
} else {
|
||||
if (next_block_idx >= num_blocks)
|
||||
return false;
|
||||
|
||||
// For SM90 only
|
||||
// NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
|
||||
is_peer_cta_alive = kNum1DBlocksPerGroup % kNumMulticast == 0 or // Always aligned on N (constant bypass)
|
||||
num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass)
|
||||
(next_block_idx ^ 1) < num_blocks; // Peer CTA in bound
|
||||
get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// For SM90 only
|
||||
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
|
||||
if (num_blocks_in_group == 1)
|
||||
return false;
|
||||
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked) {
|
||||
return true;
|
||||
} else {
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");
|
||||
if constexpr (kIsMulticastOnA) {
|
||||
return true;
|
||||
} else {
|
||||
const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
|
||||
const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
|
||||
return group_idx == peer_group_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For SM90 only
|
||||
// ReSharper disable once CppNotAllPathsReturnValue
|
||||
__device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return true;
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
|
||||
return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
||||
return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
} // namespace deep_gemm
|
||||
169
deep_gemm/include/deep_gemm/common/sm100_utils.cuh
Normal file
169
deep_gemm/include/deep_gemm/common/sm100_utils.cuh
Normal file
@@ -0,0 +1,169 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/atom/mma_traits_sm100.hpp>
|
||||
#include <cute/arch/mma_sm100_umma.hpp>
|
||||
#include <cute/arch/tmem_allocator_sm100.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
|
||||
namespace deep_gemm::sm100 {
|
||||
|
||||
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
|
||||
constexpr uint32_t get_inner_block_atom_size() {
|
||||
return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t);
|
||||
}
|
||||
|
||||
template <uint32_t BLOCK_INNER, uint32_t BLOCK_OUTER,
|
||||
uint32_t kSwizzleMode, uint32_t kNumMulticast,
|
||||
typename dtype_t>
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr,
|
||||
dtype_t* smem_ptr, const uint32_t& inner_idx, const int32_t& outer_idx) {
|
||||
DG_STATIC_ASSERT(1 <= kNumMulticast and kNumMulticast <= 2, "Invalid multicast config");
|
||||
DG_STATIC_ASSERT(static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL) ==
|
||||
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint");
|
||||
|
||||
// 2-CTA function will send signals to the leader CTA only
|
||||
const auto copy_func = kNumMulticast == 1 ? cute::SM90_TMA_LOAD_2D::copy : cute::SM100_TMA_2SM_LOAD_2D::copy;
|
||||
|
||||
// Issue multiple TMAs
|
||||
constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size<BLOCK_INNER, kSwizzleMode, dtype_t>();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
|
||||
copy_func(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
|
||||
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
|
||||
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__
|
||||
cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr,
|
||||
uint32_t stride_byte_offset, uint32_t leading_byte_offset) {
|
||||
cute::UMMA::SmemDescriptor desc;
|
||||
|
||||
// Set the version for SM100
|
||||
desc.version_ = 1;
|
||||
|
||||
// Legacy mode
|
||||
desc.lbo_mode_ = 0;
|
||||
|
||||
// Layout
|
||||
desc.layout_type_ = static_cast<uint8_t>(layout);
|
||||
|
||||
// Start address
|
||||
const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
|
||||
|
||||
// Base offset
|
||||
desc.base_offset_ = 0;
|
||||
|
||||
// SBO and LBO
|
||||
desc.stride_byte_offset_ = stride_byte_offset >> 4;
|
||||
desc.leading_byte_offset_ = leading_byte_offset >> 4;
|
||||
|
||||
return desc;
|
||||
}
|
||||
|
||||
__device__ __forceinline__
|
||||
cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) {
|
||||
// NOTES: the UTCCP layout is K-major by default
|
||||
// Atom size: 8 x 128 bits
|
||||
// {SBO, LBO} means the byte stride between atoms on {MN, K}
|
||||
// Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero
|
||||
return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0);
|
||||
}
|
||||
|
||||
__device__ __forceinline__
|
||||
void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) {
|
||||
const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
|
||||
}
|
||||
|
||||
// ReSharper disable once CppNotAllPathsReturnValue
|
||||
template <uint32_t kSwizzleMode>
|
||||
constexpr static cute::UMMA::LayoutType to_umma_layout_type() {
|
||||
DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
|
||||
kSwizzleMode == 32 or kSwizzleMode == 64 or
|
||||
kSwizzleMode == 128, "Invalid swizzling mode");
|
||||
if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE;
|
||||
if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE;
|
||||
if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B;
|
||||
if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B;
|
||||
if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B;
|
||||
}
|
||||
|
||||
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
|
||||
__device__ __forceinline__
|
||||
constexpr uint32_t get_umma_desc_stride_k() {
|
||||
return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
|
||||
}
|
||||
|
||||
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
|
||||
__device__ __forceinline__
|
||||
uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) {
|
||||
return base + ((offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) >> 4u);
|
||||
}
|
||||
|
||||
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
|
||||
__device__ __forceinline__
|
||||
cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
|
||||
const uint32_t stride_k = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
|
||||
if constexpr (kMajorMode == cute::UMMA::Major::K) {
|
||||
// NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
|
||||
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
|
||||
|
||||
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
|
||||
// {SBO, LBO} means the byte stride between atoms on {MN, K}
|
||||
// NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
|
||||
const uint32_t stride_byte_offset = 8 * BLOCK_K * sizeof(dtype_t);
|
||||
const uint32_t leading_byte_offset = 0;
|
||||
return make_smem_desc(to_umma_layout_type<kSwizzleMode>(),
|
||||
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
|
||||
stride_byte_offset, leading_byte_offset);
|
||||
} else {
|
||||
constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
|
||||
|
||||
// Must have no in-atom MN-idx
|
||||
// NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
|
||||
DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
|
||||
DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");
|
||||
|
||||
// Atom size: `kSwizzleMode` (in bytes, on MN) x 8
|
||||
// NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
|
||||
// {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
|
||||
// {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
|
||||
uint32_t stride_byte_offset = 8 * BLOCK_MN_ATOM * sizeof(dtype_t);
|
||||
uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
|
||||
if constexpr (kSwizzleMode == 16)
|
||||
swap(stride_byte_offset, leading_byte_offset);
|
||||
return make_smem_desc(to_umma_layout_type<kSwizzleMode>(),
|
||||
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
|
||||
stride_byte_offset, leading_byte_offset);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__
|
||||
uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sf_id) {
|
||||
desc.a_sf_id_ = sf_id, desc.b_sf_id_ = sf_id;
|
||||
return static_cast<uint64_t>(static_cast<uint32_t>(desc)) << 32;
|
||||
}
|
||||
|
||||
template <uint32_t kNumCols>
|
||||
__device__ constexpr uint32_t get_num_aligned_tmem_cols() {
|
||||
DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns");
|
||||
if (kNumCols <= 32) return 32;
|
||||
if (kNumCols <= 64) return 64;
|
||||
if (kNumCols <= 128) return 128;
|
||||
if (kNumCols <= 256) return 256;
|
||||
return 512;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void tcgen05_before_thread_sync() {
|
||||
asm volatile("tcgen05.fence::before_thread_sync;");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void tcgen05_after_thread_sync() {
|
||||
asm volatile("tcgen05.fence::after_thread_sync;");
|
||||
}
|
||||
|
||||
} // namespace `deep_gemm::sm100`
|
||||
@@ -1,149 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef __CUDACC_RTC__
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include <cstdint>
|
||||
#include <cute/arch/mma_sm90_gmma.hpp>
|
||||
#include <cute/arch/mma_sm90_gmma_ext.hpp>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x2_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
|
||||
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
|
||||
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x4_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
|
||||
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
|
||||
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
|
||||
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
|
||||
}
|
||||
};
|
||||
|
||||
__forceinline__ __device__ void warpgroup_arrive() {
|
||||
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void warpgroup_commit_batch() {
|
||||
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void warpgroup_fence_operand(float& reg) {
|
||||
asm volatile("" : "+f"(reg) :: "memory");
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint32_t get_lane_id() {
|
||||
uint32_t lane_id;
|
||||
asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
|
||||
return lane_id;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
|
||||
uint32_t ret;
|
||||
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
|
||||
int4 ret;
|
||||
asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
|
||||
float ret;
|
||||
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
|
||||
float2 ret;
|
||||
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
|
||||
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
|
||||
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
|
||||
asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ void warpgroup_wait() {
|
||||
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
|
||||
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
|
||||
}
|
||||
|
||||
union GmmaDescriptor {
|
||||
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
uint64_t desc_;
|
||||
uint32_t reg32_[2];
|
||||
uint16_t reg16_[4];
|
||||
|
||||
struct {
|
||||
uint16_t start_address_: 14, : 2;
|
||||
uint16_t leading_byte_offset_: 14, : 2;
|
||||
uint16_t stride_byte_offset_: 14, : 2;
|
||||
uint8_t : 1, base_offset_: 3, : 4;
|
||||
uint8_t : 6, layout_type_: 2;
|
||||
} bitfield;
|
||||
|
||||
// Decay to an `uint64_t`
|
||||
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
|
||||
};
|
||||
|
||||
template <class PointerType>
|
||||
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
|
||||
int leading_byte_offset = 0,
|
||||
int stride_byte_offset = 1024) {
|
||||
GmmaDescriptor desc;
|
||||
auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
desc.bitfield.start_address_ = uint_ptr >> 4;
|
||||
desc.bitfield.layout_type_ = layout_type;
|
||||
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
|
||||
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
|
||||
desc.bitfield.base_offset_ = 0;
|
||||
return desc;
|
||||
}
|
||||
namespace deep_gemm::sm90 {
|
||||
|
||||
template <int N_, typename MMA>
|
||||
struct FP8MMA {
|
||||
|
||||
|
||||
template <size_t ...Idx>
|
||||
__forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence<Idx...>) {
|
||||
using namespace cute::SM90::GMMA;
|
||||
@@ -194,19 +59,93 @@ struct FP8MMASelector {
|
||||
using type = decltype(select_type());
|
||||
};
|
||||
|
||||
enum class Layout {
|
||||
RowMajor,
|
||||
ColMajor
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x2_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
|
||||
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
|
||||
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
|
||||
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
|
||||
}
|
||||
};
|
||||
|
||||
__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
|
||||
return block_m == 64 ? 1 : 2;
|
||||
__forceinline__ __device__ void warpgroup_arrive() {
|
||||
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
|
||||
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
|
||||
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
|
||||
return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
|
||||
__forceinline__ __device__ void warpgroup_commit_batch() {
|
||||
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
__forceinline__ __device__ void warpgroup_fence_operand(float& reg) {
|
||||
asm volatile("" : "+f"(reg) :: "memory");
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__forceinline__ __device__ void warpgroup_wait() {
|
||||
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
|
||||
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
|
||||
}
|
||||
|
||||
// TODO: replace with CUTLASS solution
|
||||
union GmmaDescriptor {
|
||||
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
|
||||
desc_ = t.desc_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
uint64_t desc_;
|
||||
uint32_t reg32_[2];
|
||||
uint16_t reg16_[4];
|
||||
|
||||
struct {
|
||||
uint16_t start_address_: 14, : 2;
|
||||
uint16_t leading_byte_offset_: 14, : 2;
|
||||
uint16_t stride_byte_offset_: 14, : 2;
|
||||
uint8_t : 1, base_offset_: 3, : 4;
|
||||
uint8_t : 6, layout_type_: 2;
|
||||
} bitfield;
|
||||
|
||||
// Decay to an `uint64_t`
|
||||
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
|
||||
};
|
||||
|
||||
template <class PointerType>
|
||||
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type,
|
||||
const int& leading_byte_offset = 0,
|
||||
const int& stride_byte_offset = 1024) {
|
||||
GmmaDescriptor desc;
|
||||
const auto& uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
desc.bitfield.start_address_ = uint_ptr >> 4;
|
||||
desc.bitfield.layout_type_ = layout_type;
|
||||
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
|
||||
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
|
||||
desc.bitfield.base_offset_ = 0;
|
||||
return desc;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& num_tma_multicast) {
|
||||
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
|
||||
if (num_tma_multicast == 1) {
|
||||
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
} else if (cute::block_rank_in_cluster() == 0) {
|
||||
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace `deep_gemm::sm90`
|
||||
17
deep_gemm/include/deep_gemm/common/types.hpp
Normal file
17
deep_gemm/include/deep_gemm/common/types.hpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class GemmType {
|
||||
Normal = 0,
|
||||
MGroupedContiguous = 1,
|
||||
MGroupedMasked = 2,
|
||||
KGroupedContiguous = 3,
|
||||
};
|
||||
|
||||
enum class KernelType {
|
||||
Kernel1D1D = 0,
|
||||
Kernel1D2D = 1,
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
138
deep_gemm/include/deep_gemm/common/utils.cuh
Normal file
138
deep_gemm/include/deep_gemm/common/utils.cuh
Normal file
@@ -0,0 +1,138 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
|
||||
asm volatile("trap;");
|
||||
}
|
||||
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
#ifndef DG_DEVICE_ASSERT
|
||||
#define DG_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
|
||||
asm("trap;"); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_TRAP_ONLY_DEVICE_ASSERT
|
||||
#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) \
|
||||
asm("trap;"); \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_STATIC_ASSERT
|
||||
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
|
||||
#endif
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <typename FuncT>
|
||||
struct PatternVisitor {
|
||||
FuncT func;
|
||||
|
||||
__device__ __host__
|
||||
explicit PatternVisitor(FuncT&& func): func(std::forward<FuncT>(func)) {}
|
||||
|
||||
__device__ __host__
|
||||
auto operator [](const uint32_t& i) {
|
||||
return func(i);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ T ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ T align(T a, T b) {
|
||||
return ceil_div(a, b) * b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ constexpr T constexpr_align(T a, T b) {
|
||||
return constexpr_ceil_div(a, b) * b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
|
||||
return b == 0 ? a : constexpr_gcd(b, a % b);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__forceinline__ __device__ void swap(T& a, T& b) {
|
||||
T temp = a;
|
||||
a = b;
|
||||
b = temp;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint32_t get_sm_idx() {
|
||||
uint32_t sm_idx;
|
||||
asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx));
|
||||
return sm_idx;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ uint32_t get_lane_idx() {
|
||||
uint32_t lane_id;
|
||||
asm ("mov.u32 %0, %laneid;" : "=r"(lane_id));
|
||||
return lane_id;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) {
|
||||
uint32_t ret;
|
||||
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float4 ld_shared(const float4* ptr) {
|
||||
float4 ret;
|
||||
asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) {
|
||||
uint4 ret;
|
||||
asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ld_shared(const float* ptr) {
|
||||
float ret;
|
||||
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
|
||||
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
|
||||
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
|
||||
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(ptr), "r"(x), "r"(y), "r"(z), "r"(w));
|
||||
}
|
||||
|
||||
template <typename old_t>
|
||||
__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
|
||||
auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});
|
||||
return *reinterpret_cast<int*>(&bf16x2);
|
||||
}
|
||||
|
||||
} // namespace `deep_gemm`
|
||||
@@ -1,363 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include "mma_utils.cuh"
|
||||
#include "scheduler.cuh"
|
||||
#include "tma_utils.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
|
||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||
fp8_wgrad_gemm_kernel(uint32_t shape_k,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
||||
|
||||
// Shared memory
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
|
||||
static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, 128U) * 128U;
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
|
||||
|
||||
const uint32_t shape_k_scales = ceil_div(shape_k, BLOCK_K);
|
||||
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<float*>(smem_buffer);
|
||||
__nv_fp8_e4m3* smem_a[kNumStages];
|
||||
__nv_fp8_e4m3* smem_b[kNumStages];
|
||||
float* smem_scales_a[kNumStages];
|
||||
float* smem_scales_b[kNumStages];
|
||||
|
||||
// TMA Barrier for both divisible and non-divisible cases
|
||||
Barrier* full_barriers[kNumStages + 1];
|
||||
Barrier* empty_barriers[kNumStages + 1];
|
||||
|
||||
// Fill shared memory pointers
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)
|
||||
+ i * SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
smem_scales_b[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)
|
||||
+ i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fill barriers
|
||||
DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages
|
||||
* (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages + 1; ++ i) {
|
||||
full_barriers[i] = barrier_start_ptr + i;
|
||||
empty_barriers[i] = barrier_start_ptr + kNumStages + 1 + i;
|
||||
}
|
||||
|
||||
// Initialize barriers
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast");
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
|
||||
// even with TMA multicast disabled, we want to make the behavior aligned
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
||||
}
|
||||
full_barriers[kNumStages]->init(1);
|
||||
empty_barriers[kNumStages]->init(1);
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
auto launch_k_iterations = [&](const auto& func) {
|
||||
if constexpr (kNumLastStages == 0) {
|
||||
for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
} else {
|
||||
for (int k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
func(num_iterations - 1, NotDivisibleK{});
|
||||
}
|
||||
};
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr int kNumTMARegisters = 40;
|
||||
constexpr int kNumMathRegisters = 232;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<GemmType::Normal, SHAPE_N, BLOCK_M, BLOCK_N, 1, kNumTMAMulticast, kIsTMAMulticastOnA>(SHAPE_M);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
// Assign TMA multicast number into A and B
|
||||
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
||||
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
||||
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
|
||||
// Issue TMA A
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, m_block_idx * BLOCK_M, num_tma_multicast_a);
|
||||
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_a[s], m_block_idx * BLOCK_M,
|
||||
k_idx / BLOCK_K, num_tma_multicast_a);
|
||||
|
||||
// Issue TMA B
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, n_block_idx * BLOCK_N, num_tma_multicast_b);
|
||||
tma_copy(&tensor_map_scales_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_b[s], n_block_idx * BLOCK_N, k_idx / BLOCK_K, num_tma_multicast_b);
|
||||
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
});
|
||||
|
||||
// Issue TMA D
|
||||
empty_barriers[kNumStages]->wait((scheduler.current_iter + 1) & 1);
|
||||
auto& full_barrier = *full_barriers[kNumStages];
|
||||
tma_copy(&tensor_map_d, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M, 1);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_D_SIZE);
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
|
||||
const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
|
||||
const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](int s) {
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
|
||||
}
|
||||
};
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Decide the number of scales B to load
|
||||
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Accumulation for WGMMA or CUDA promotion
|
||||
constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
|
||||
float2 scales_b[WGMMA::kNumAccum / 4];
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](int k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
|
||||
// Read A scales
|
||||
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
|
||||
auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
// Read B scales at the first warpgroup wave
|
||||
if (local_idx == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
|
||||
scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_scales_b[s] + i * 8 + col_idx * 2));
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival at the last warpgroup wave
|
||||
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
||||
empty_barrier_arrive(s);
|
||||
|
||||
// Promote with scales
|
||||
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
const float &scale_b_0 = scales_b[i].x;
|
||||
const float &scale_b_1 = scales_b[i].y;
|
||||
shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
|
||||
shifted_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
|
||||
shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
|
||||
shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait last TMA store to be finished
|
||||
if (k_iter == 0 and scheduler.current_iter > 0) {
|
||||
if (threadIdx.x == 0) {
|
||||
cute::tma_store_wait<0>();
|
||||
empty_barriers[kNumStages]->arrive();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
});
|
||||
|
||||
// Wait TMA D arrivals
|
||||
full_barriers[kNumStages]->wait(scheduler.current_iter & 1);
|
||||
|
||||
// Accumulate to D shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
|
||||
auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + r_0) * BLOCK_N + col_idx * 2);
|
||||
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + r_1) * BLOCK_N + col_idx * 2);
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
float2 d_0 = ld_shared(smem_d_0 + i * 4);
|
||||
st_shared(smem_d_0 + i * 4, {d_0.x + shifted_accum[i * 4 + 0], d_0.y + shifted_accum[i * 4 + 1]});
|
||||
float2 d_1 = ld_shared(smem_d_1 + i * 4);
|
||||
st_shared(smem_d_1 + i * 4, {d_1.x + shifted_accum[i * 4 + 2], d_1.y + shifted_accum[i * 4 + 3]});
|
||||
}
|
||||
}
|
||||
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
if (threadIdx.x == 0) {
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M);
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false && "This kernel only support sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
3
deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh
Normal file
3
deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh
Normal file
@@ -0,0 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: add implement
|
||||
601
deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh
Normal file
601
deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh
Normal file
@@ -0,0 +1,601 @@
|
||||
#pragma once
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm100_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm100;
|
||||
|
||||
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t>
|
||||
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
||||
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_sfa,
|
||||
const __grid_constant__ CUtensorMap tensor_map_sfb,
|
||||
const __grid_constant__ CUtensorMap tensor_map_c,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// GEMM with accumulation must have FP32 output
|
||||
if constexpr (kWithAccumulation)
|
||||
DG_STATIC_ASSERT(std::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t LAYOUT_AD_M = 128;
|
||||
constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M;
|
||||
constexpr uint32_t kNumTMAStoreStages = 2;
|
||||
constexpr uint32_t kNumSFStagesPerLoad = sizeof(uint32_t) / sizeof(cutlass::float_ue8m0_t);
|
||||
constexpr uint32_t kNumUTCCPAlignedElems = 128;
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
|
||||
DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
|
||||
|
||||
// Overwrite shape constants if the compiler gives
|
||||
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
||||
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
||||
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
||||
const uint32_t shape_sf_k = ceil_div(shape_k, BLOCK_K * kNumSFStagesPerLoad);
|
||||
|
||||
// Utils
|
||||
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
||||
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
const auto lane_idx = get_lane_idx();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
|
||||
// 2-CTA MMA
|
||||
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
||||
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t STORE_BLOCK_M = std::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
||||
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
|
||||
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
|
||||
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D");
|
||||
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
|
||||
|
||||
// Share memory sizes
|
||||
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
|
||||
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
|
||||
constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
|
||||
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
|
||||
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
|
||||
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
|
||||
|
||||
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
|
||||
// TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
|
||||
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
|
||||
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
|
||||
constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2;
|
||||
|
||||
// Real tensor memory size and offsets
|
||||
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
|
||||
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
|
||||
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
|
||||
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == 0) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sfa);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sfb);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
if constexpr (kWithAccumulation)
|
||||
cute::prefetch_tma_descriptor(&tensor_map_c);
|
||||
}
|
||||
|
||||
// Data on shared memory (layout as ordered below)
|
||||
cd_dtype_t* smem_cd[kNumTMAStoreStages];
|
||||
cutlass::float_e4m3_t* smem_a[kNumStages];
|
||||
cutlass::float_e4m3_t* smem_b[kNumStages];
|
||||
uint32_t* smem_sfa[kNumStages];
|
||||
uint32_t* smem_sfb[kNumStages];
|
||||
|
||||
// Fill D/A/B pointers
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
|
||||
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fill SFA/SFB
|
||||
auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_sfa[i] = reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
|
||||
smem_sfb[i] = reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer +
|
||||
SMEM_CD_SIZE +
|
||||
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) +
|
||||
kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE));
|
||||
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
||||
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
||||
auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
||||
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
|
||||
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
|
||||
|
||||
// Fill the tensor memory pointer
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
// Initialize barriers
|
||||
if (threadIdx.x == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
// Arrive at all CTAs
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(1);
|
||||
// Arrive only at the leader CTA
|
||||
with_sf_full_barriers[i]->init(kNumMulticast * 32);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
|
||||
// Arrive at all CTAs
|
||||
tmem_full_barriers[i]->init(1);
|
||||
// Arrive only at the leader CTA
|
||||
tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
|
||||
// Allocate tensor memory
|
||||
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA>(shape_m, shape_n, grouped_layout);
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
uint32_t phase = 0;
|
||||
auto launch_k_iterations = [&](const auto& func) {
|
||||
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
|
||||
const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K);
|
||||
const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages;
|
||||
|
||||
// TODO: refactor here
|
||||
if (num_last_stages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1)
|
||||
func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages);
|
||||
} else {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1)
|
||||
func(k_iter, DivisibleK{}, false, num_last_stages);
|
||||
func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1;
|
||||
}
|
||||
};
|
||||
|
||||
auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) {
|
||||
DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2,
|
||||
"Too many epilogue stages, please modify the Python heuristic as well");
|
||||
accum_stage_idx == 0 ? func(0) : func(1);
|
||||
};
|
||||
|
||||
// Dispatch warps into different roles
|
||||
if (warp_idx == 0) {
|
||||
// TMA load warp
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait(phase ^ 1);
|
||||
|
||||
// Compute offsets
|
||||
// NOTES: the group is always concatenated with the outer dimension
|
||||
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
|
||||
shape_m, BLOCK_M, m_block_idx);
|
||||
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
|
||||
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
||||
|
||||
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
||||
// And for all m-grouped GEMMs, A must be K-majored
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
||||
uint32_t k_block_idx = k_iter * kNumStages + s;
|
||||
uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
|
||||
// Add 2 CTA offsets
|
||||
if constexpr (kNumMulticast > 1) {
|
||||
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
|
||||
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
||||
}
|
||||
|
||||
// Issue TMAs
|
||||
if (cute::elect_one_sync()) {
|
||||
if constexpr (kMajorA == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx);
|
||||
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
|
||||
}
|
||||
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
|
||||
// Issue SFA and SFB TMAs at certain stages
|
||||
// No swizzling, so one TMA for one SF is enough
|
||||
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
||||
tma_copy<BLOCK_M, 1, 0, 1>(&tensor_map_sfa, full_barriers[s], smem_sfa[s], m_block_idx * BLOCK_M,
|
||||
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
|
||||
tma_copy<BLOCK_N, 1, 0, 1>(&tensor_map_sfb, full_barriers[s], smem_sfb[s], n_block_idx * BLOCK_N,
|
||||
scheduler.template get_global_idx<true, KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
|
||||
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
|
||||
}
|
||||
|
||||
// Arrive at full barriers
|
||||
if (cute::elect_one_sync())
|
||||
full_barriers[s]->arrive_and_expect_tx(num_arrival_bytes);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait(phase ^ 1);
|
||||
if (cute::elect_one_sync())
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
});
|
||||
}
|
||||
} else if (warp_idx == 1 and is_leader_cta) {
|
||||
// MMA issue warp
|
||||
// NOTES: only the leader CTA will do this
|
||||
// Make instruction descriptor
|
||||
// TODO: refactor `UMMA_M` calculation
|
||||
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
|
||||
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
||||
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e4m3_t, cutlass::float_e4m3_t,
|
||||
float, cutlass::float_ue8m0_t,
|
||||
UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
||||
auto sf_desc = make_sf_desc(nullptr);
|
||||
|
||||
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
||||
auto a_desc = make_umma_desc<kMajorA, BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
||||
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
||||
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
||||
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
||||
|
||||
// Checks for MMA instructions
|
||||
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
|
||||
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
||||
"Invalid MMA instruction shape");
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
|
||||
// Wait tensor memory empty barrier arrival
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) {
|
||||
auto umma_arrive = [](const uint64_t* barrier) {
|
||||
if constexpr (kNumMulticast == 1) {
|
||||
cutlass::arch::umma_arrive(barrier);
|
||||
} else {
|
||||
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
||||
}
|
||||
};
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[s]));
|
||||
|
||||
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
||||
if (do_tmem_full_arrive)
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA and SF-transpose arrival
|
||||
with_sf_full_barriers[s]->wait(phase);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Do SF copy at certain stages
|
||||
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
|
||||
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
||||
using cute_utccp_t = std::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
|
||||
|
||||
// SFA and SFB copy
|
||||
// TODO: process shared memory descriptor by addition
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfa[s] + i * kNumUTCCPAlignedElems;
|
||||
replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfb[s] + i * kNumUTCCPAlignedElems;
|
||||
replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
using cute_mma_t = std::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_MMA_MXF8F6F4_SS <cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
||||
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>,
|
||||
cute::SM100_MMA_MXF8F6F4_2x1SM_SS<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
||||
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>>;
|
||||
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
|
||||
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s);
|
||||
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s);
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
|
||||
cute_mma_t::fma(a_desc, b_desc,
|
||||
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
||||
k_iter > 0 or s > 0 or k > 0,
|
||||
runtime_instr_desc,
|
||||
kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
|
||||
kTmemStartColOfSFB);
|
||||
}
|
||||
}
|
||||
|
||||
// Commit to the mbarrier object
|
||||
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
||||
empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
with_sf_full_barriers[s]->wait(phase);
|
||||
empty_barrier_arrive(s, false);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
} else if (warp_idx == 2) {
|
||||
// UTCCP transposer
|
||||
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
|
||||
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
|
||||
uint32_t values[4];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; ++ i)
|
||||
values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
|
||||
__syncwarp();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; ++ i)
|
||||
st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
|
||||
};
|
||||
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA arrival
|
||||
full_barriers[s]->wait(phase);
|
||||
|
||||
// Transpose for UTCCP at certain stages
|
||||
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
|
||||
utccp_required_smem_warp_transpose(smem_sfa[s] + i * kNumUTCCPAlignedElems);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
|
||||
utccp_required_smem_warp_transpose(smem_sfb[s] + i * kNumUTCCPAlignedElems);
|
||||
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
}
|
||||
|
||||
// Arrive
|
||||
with_sf_full_barriers[s]->arrive(0u);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait(phase);
|
||||
with_sf_full_barriers[s]->arrive(0u);
|
||||
}
|
||||
});
|
||||
}
|
||||
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
|
||||
// Epilogue warp groups
|
||||
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
|
||||
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
|
||||
|
||||
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
||||
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
|
||||
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
||||
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
|
||||
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
||||
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
|
||||
// Flush TMA stores
|
||||
// NOTES: for the first store, we have to flush all previous TMA,
|
||||
// as we don't share pipeline stages between two blocks
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
|
||||
// Wait UMMA arrival
|
||||
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Load from tensor memory into registers, and write shared memory with STSM
|
||||
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
|
||||
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
||||
|
||||
// Iterate over M waves
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
// Issue every swizzled atom and pipeline STSM and TMA store
|
||||
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStores; ++ s) {
|
||||
// Wait shared memory to be released
|
||||
const uint32_t iter_idx = w * kNumStores + s;
|
||||
if (iter_idx >= kNumTMAStoreStages) {
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
}
|
||||
|
||||
// The pipeline stage
|
||||
const auto tma_stage_idx = iter_idx % kNumTMAStoreStages;
|
||||
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
|
||||
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
||||
|
||||
// Store into shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
||||
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
||||
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
||||
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleCDMode / 16);
|
||||
|
||||
// Source and destination memory address
|
||||
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
|
||||
w * BLOCK_N + // Wave offset
|
||||
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
||||
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
||||
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
|
||||
// Load from tensor memory, store into shared memory
|
||||
uint32_t values[kNumElemsPerBankGroup];
|
||||
if constexpr (std::is_same_v<cd_dtype_t, float>) {
|
||||
// For FP32 output, read and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
||||
} else {
|
||||
// For BF16 output, read, cast and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and std::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3],
|
||||
values[4], values[5], values[6], values[7]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr,
|
||||
cast_into_bf16_and_pack(values[0], values[1]),
|
||||
cast_into_bf16_and_pack(values[2], values[3]),
|
||||
cast_into_bf16_and_pack(values[4], values[5]),
|
||||
cast_into_bf16_and_pack(values[6], values[7]));
|
||||
}
|
||||
}
|
||||
|
||||
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
|
||||
// NOTES: only the last stage needs to do this
|
||||
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
|
||||
tcgen05_before_thread_sync();
|
||||
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
if (epilogue_thread_idx == 0) {
|
||||
using cute_tma_t = std::conditional_t<kWithAccumulation,
|
||||
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
||||
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
|
||||
// TODO: do we actually need this?
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
|
||||
// Deallocate tensor memory by warp 1
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
// TODO: do we need 2 SM allocation?
|
||||
if (epilogue_warp_idx == 1)
|
||||
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
// To safely deconstruct all barriers, we need a cluster sync
|
||||
// TODO: optimize it by another round of barrier waits
|
||||
if constexpr (kNumMulticast > 1)
|
||||
cute::cluster_sync();
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
|
||||
#endif
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
532
deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh
Normal file
532
deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh
Normal file
@@ -0,0 +1,532 @@
|
||||
#pragma once
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm100_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm100;
|
||||
|
||||
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
GemmType kGemmType, typename cd_dtype_t>
|
||||
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
||||
sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d,
|
||||
const __grid_constant__ CUtensorMap tensor_map_sfa) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t LAYOUT_AD_M = 128;
|
||||
constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M;
|
||||
constexpr uint32_t kNumTMAStoreStages = 2;
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
|
||||
DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
|
||||
DG_STATIC_ASSERT(BLOCK_M == kNumEpilogueThreads, "Invalid block M");
|
||||
|
||||
// Overwrite shape constants if the compiler gives
|
||||
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
||||
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
||||
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
||||
const auto shape_k_scales = ceil_div(shape_k, BLOCK_K);
|
||||
|
||||
// Utils
|
||||
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
||||
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
const auto lane_idx = get_lane_idx();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
|
||||
// 2-CTA MMA
|
||||
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
||||
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t STORE_BLOCK_M = std::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
||||
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
|
||||
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
|
||||
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D");
|
||||
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
|
||||
|
||||
// Share memory sizes
|
||||
// NOTES: do not use `LOAD_BLOCK_M` for SFA, as we need full SFA for promotion
|
||||
constexpr bool kMustUseUniformedSFB = (BLOCK_K % BLOCK_N == 0);
|
||||
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode;
|
||||
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
|
||||
|
||||
// Must have 2 epilogue stages
|
||||
constexpr uint32_t kNumEpilogueStages = 2;
|
||||
|
||||
// Real tensor memory size and offsets
|
||||
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
|
||||
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == 0) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sfa);
|
||||
}
|
||||
|
||||
// Data on shared memory (layout as ordered below)
|
||||
cd_dtype_t* smem_cd[kNumTMAStoreStages];
|
||||
cutlass::float_e4m3_t* smem_a[kNumStages];
|
||||
cutlass::float_e4m3_t* smem_b[kNumStages];
|
||||
float* smem_sfa[kNumStages];
|
||||
|
||||
// Fill D/A/B pointers
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
|
||||
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fill SFA/SFB
|
||||
auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i)
|
||||
smem_sfa[i] = reinterpret_cast<float*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer +
|
||||
SMEM_CD_SIZE +
|
||||
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) +
|
||||
kNumStages * SMEM_SFA_SIZE_PER_STAGE);
|
||||
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
||||
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
||||
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
||||
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
|
||||
|
||||
// Fill the tensor memory pointer
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 2 + kNumEpilogueStages * 2);
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
// Initialize barriers
|
||||
if (threadIdx.x == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
// Arrive at all CTAs
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads / 32);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
|
||||
// Arrive at all CTAs
|
||||
tmem_full_barriers[i]->init(1);
|
||||
// Arrive only at the leader CTA
|
||||
tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
|
||||
// Allocate tensor memory
|
||||
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
const uint32_t num_iterations = ceil_div(shape_k, kNumStages * BLOCK_K);
|
||||
auto launch_k_iterations = [=](const auto& func) {
|
||||
if constexpr (kNumLastStages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
} else {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
func(num_iterations - 1, NotDivisibleK{});
|
||||
}
|
||||
};
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA>(shape_m, shape_n, grouped_layout);
|
||||
|
||||
// Register configurations
|
||||
constexpr uint32_t kNumNonEpilogueRegisters = 64;
|
||||
constexpr uint32_t kNumEpilogueRegisters = 216;
|
||||
DG_STATIC_ASSERT(kNumNonEpilogueRegisters * kNumNonEpilogueThreads + kNumEpilogueRegisters * kNumEpilogueThreads <= 65535, "Too many registers");
|
||||
|
||||
// Dispatch warps into different roles
|
||||
if (warp_idx == 0) {
|
||||
// Adjust registers
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
|
||||
|
||||
// TMA load warp
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
|
||||
// Compute offsets
|
||||
// NOTES: the group is always concatenated with the outer dimension
|
||||
uint32_t m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(
|
||||
shape_m, BLOCK_M, m_block_idx);
|
||||
uint32_t n_idx = scheduler.get_global_idx<(kMajorB == cute::UMMA::Major::K)>(
|
||||
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
||||
|
||||
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
||||
// And for all grouped GEMMs, A must be K-majored
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
||||
uint32_t k_block_idx = k_iter * kNumStages + s;
|
||||
uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
uint32_t k_b_idx = scheduler.get_global_idx<(kMajorB == cute::UMMA::Major::MN)>(
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
|
||||
// Add 2 CTA offsets
|
||||
if constexpr (kNumMulticast > 1) {
|
||||
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
|
||||
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
||||
}
|
||||
|
||||
// Issue TMAs
|
||||
if (cute::elect_one_sync()) {
|
||||
if constexpr (kMajorA == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], k_idx, m_idx);
|
||||
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
|
||||
|
||||
// Issue SFA TMA
|
||||
tma_copy<BLOCK_M, 1, 0, kNumMulticast>(
|
||||
&tensor_map_sfa, full_barriers[s],
|
||||
smem_sfa[s], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_k_scales, 1, k_block_idx));
|
||||
}
|
||||
|
||||
// Arrive at full barriers
|
||||
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE;
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
});
|
||||
}
|
||||
} else if (warp_idx == 1 and is_leader_cta) {
|
||||
// Adjust registers
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
|
||||
|
||||
// MMA issue warp
|
||||
// NOTES: only the leader CTA will do this
|
||||
// Make instruction descriptor
|
||||
// TODO: refactor `UMMA_M` calculation
|
||||
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
|
||||
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
||||
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
||||
UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
||||
auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
||||
|
||||
// Checks for MMA instructions
|
||||
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
|
||||
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
||||
"Invalid MMA instruction shape");
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s) {
|
||||
// Wait TMA full
|
||||
auto iter_idx = scheduler.current_iter * num_iterations + k_iter;
|
||||
full_barriers[s]->wait(iter_idx & 1);
|
||||
|
||||
// Wait tensor memory empty
|
||||
auto accum_stage_idx = (iter_idx * kNumStages + s) % kNumEpilogueStages;
|
||||
auto accum_stage_phase = ((iter_idx * kNumStages + s) / kNumEpilogueStages) & 1;
|
||||
tmem_empty_barriers[accum_stage_idx]->wait(accum_stage_phase ^ 1);
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
if (s < kNumInnerStages) {
|
||||
using cute_mma_t = std::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_MMA_F8F6F4_SS, cute::SM100_MMA_F8F6F4_2x1SM_SS>;
|
||||
tcgen05_after_thread_sync();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[s], 0, k * UMMA_K);
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
auto a_desc = make_umma_desc<kMajorA, BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[s], w * LAYOUT_AD_M, k * UMMA_K);
|
||||
cute_mma_t::fma(a_desc, b_desc,
|
||||
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
||||
k > 0,
|
||||
runtime_instr_desc);
|
||||
}
|
||||
}
|
||||
tcgen05_before_thread_sync();
|
||||
}
|
||||
|
||||
// Commit to the TMA empty and tensor memory full barrier
|
||||
auto umma_arrive = [](const uint64_t* barrier) {
|
||||
if constexpr (kNumMulticast == 1) {
|
||||
cutlass::arch::umma_arrive(barrier);
|
||||
} else {
|
||||
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
||||
}
|
||||
};
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
||||
}
|
||||
});
|
||||
}
|
||||
} else if (warp_idx < kNumNonEpilogueThreads / 32) {
|
||||
// Adjust registers
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
|
||||
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
|
||||
// Adjust registers
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumEpilogueRegisters>();
|
||||
|
||||
// Epilogue warp groups
|
||||
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
|
||||
const auto epilogue_thread_idx_in_warpgroup = epilogue_thread_idx % 128;
|
||||
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
|
||||
const auto epilogue_warpgroup_idx = epilogue_thread_idx / 128;
|
||||
|
||||
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
||||
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
|
||||
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
||||
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
|
||||
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
||||
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
constexpr uint32_t kNumElemsPerLDTM = 16;
|
||||
DG_STATIC_ASSERT(kNumElemsPerLDTM == 16 and BLOCK_N % kNumElemsPerLDTM == 0 and BLOCK_K % kNumElemsPerLDTM == 0, "Invalid LDTM width");
|
||||
|
||||
// SFB stuffs
|
||||
uint32_t num_former_iters = BLOCK_N, num_full_iters = BLOCK_N;
|
||||
if constexpr (not kMustUseUniformedSFB) {
|
||||
num_former_iters = min(BLOCK_N, BLOCK_K - ((n_block_idx * BLOCK_N) % BLOCK_K));
|
||||
num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N);
|
||||
}
|
||||
num_former_iters /= kNumElemsPerLDTM, num_full_iters /= kNumElemsPerLDTM;
|
||||
const auto sfb_offset = scheduler.get_global_idx<true>(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx);
|
||||
const auto sfb_ptr = sfb + (sfb_offset + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales;
|
||||
|
||||
// Launch promotion
|
||||
float accum[BLOCK_N] = {0};
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s) {
|
||||
// Load SFB
|
||||
float sf_0 = 0, sf_1 = 0;
|
||||
if (s < kNumInnerStages) {
|
||||
const auto k_block_idx = k_iter * kNumStages + s;
|
||||
sf_0 = __ldg(sfb_ptr + k_block_idx);
|
||||
sf_1 = num_former_iters < num_full_iters ? __ldg(sfb_ptr + k_block_idx + shape_k_scales) : 0;
|
||||
}
|
||||
|
||||
// Wait UMMA arrival
|
||||
auto iter_idx = scheduler.current_iter * num_iterations + k_iter;
|
||||
auto accum_stage_idx = (iter_idx * kNumStages + s) % kNumEpilogueStages;
|
||||
auto accum_stage_phase = ((iter_idx * kNumStages + s) / kNumEpilogueStages) & 1;
|
||||
tmem_full_barriers[accum_stage_idx]->wait(accum_stage_phase);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Commit to the TMA empty barrier for all CTAs after loading SFA
|
||||
float sfa = s < kNumInnerStages ? ld_shared(smem_sfa[s] + epilogue_thread_idx) : 0;
|
||||
sf_0 *= sfa, sf_1 *= sfa;
|
||||
__syncwarp();
|
||||
if (lane_idx < kNumMulticast)
|
||||
empty_barriers[s]->arrive(lane_idx);
|
||||
__syncwarp();
|
||||
|
||||
// Do promotion like the SM90 kernel
|
||||
if (s < kNumInnerStages) {
|
||||
uint32_t values[kNumElemsPerLDTM];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerLDTM; ++ i) {
|
||||
// Load from tensor memory
|
||||
cute::SM100_TMEM_LOAD_32dp32b16x::copy(
|
||||
accum_stage_idx * kNumMWaves * BLOCK_N + epilogue_warpgroup_idx * BLOCK_N + i * kNumElemsPerLDTM,
|
||||
values[ 0], values[ 1], values[ 2], values[ 3],
|
||||
values[ 4], values[ 5], values[ 6], values[ 7],
|
||||
values[ 8], values[ 9], values[10], values[11],
|
||||
values[12], values[13], values[14], values[15]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
|
||||
// Promote
|
||||
const auto sf = (kMustUseUniformedSFB or i < num_former_iters) ? sf_0 : sf_1;
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumElemsPerLDTM; ++ j)
|
||||
accum[i * kNumElemsPerLDTM + j] += *reinterpret_cast<float*>(&values[j]) * sf;
|
||||
}
|
||||
}
|
||||
|
||||
// Commit to the tensor memory empty barrier (only at the leader CTA)
|
||||
tcgen05_before_thread_sync();
|
||||
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
||||
}
|
||||
});
|
||||
|
||||
// Flush TMA stores
|
||||
// NOTES: for the first store, we have to flush all previous TMA,
|
||||
// as we don't share pipeline stages between two blocks
|
||||
if (epilogue_thread_idx_in_warpgroup == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
|
||||
|
||||
// Write shared memory
|
||||
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
||||
|
||||
// Epilogue store and addition
|
||||
// Issue every swizzled atom and pipeline: store shared, add C, and TMA store
|
||||
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStores; ++ s) {
|
||||
// Wait shared memory to be released
|
||||
if (s >= kNumTMAStoreStages) {
|
||||
if (epilogue_thread_idx_in_warpgroup == 0)
|
||||
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
|
||||
}
|
||||
|
||||
// The pipeline stage
|
||||
const auto tma_stage_idx = s % kNumTMAStoreStages;
|
||||
const auto m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_m, BLOCK_M, m_block_idx);
|
||||
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
||||
const auto local_smem_cd = smem_cd[tma_stage_idx] + epilogue_warpgroup_idx * STORE_BLOCK_M * STORE_BLOCK_N;
|
||||
|
||||
// Store into shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
||||
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
||||
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
||||
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleCDMode / 16);
|
||||
|
||||
// Source and destination memory address
|
||||
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
||||
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
|
||||
// Load from tensor memory, store into shared memory
|
||||
// NOTES: if you want to do accumulation, please notice that you need two accumulation barriers
|
||||
const auto offset = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup;
|
||||
if constexpr (std::is_same_v<cd_dtype_t, float>) {
|
||||
// For FP32 output, read and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
st_shared(smem_ptr,
|
||||
*reinterpret_cast<uint32_t*>(&accum[offset + 0]),
|
||||
*reinterpret_cast<uint32_t*>(&accum[offset + 1]),
|
||||
*reinterpret_cast<uint32_t*>(&accum[offset + 2]),
|
||||
*reinterpret_cast<uint32_t*>(&accum[offset + 3]));
|
||||
} else {
|
||||
// For BF16 output, read, cast and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and std::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
st_shared(smem_ptr,
|
||||
cast_into_bf16_and_pack(accum[offset + 0], accum[offset + 1]),
|
||||
cast_into_bf16_and_pack(accum[offset + 2], accum[offset + 3]),
|
||||
cast_into_bf16_and_pack(accum[offset + 4], accum[offset + 5]),
|
||||
cast_into_bf16_and_pack(accum[offset + 6], accum[offset + 7]));
|
||||
}
|
||||
}
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
|
||||
if (epilogue_thread_idx_in_warpgroup == 0) {
|
||||
cute::SM90_TMA_STORE_2D::copy(
|
||||
&tensor_map_d, local_smem_cd,
|
||||
n_idx, m_idx + epilogue_warpgroup_idx * STORE_BLOCK_M);
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
|
||||
// TODO: do we actually need this?
|
||||
if (epilogue_thread_idx_in_warpgroup == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
|
||||
// Deallocate tensor memory by warp 1
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
// TODO: do we need 2 SM allocation?
|
||||
if (epilogue_warp_idx == 1)
|
||||
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
// To safely deconstruct all barriers, we need a cluster sync
|
||||
// TODO: optimize it by another round of barrier waits
|
||||
if constexpr (kNumMulticast > 1)
|
||||
cute::cluster_sync();
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
|
||||
#endif
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
3
deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh
Normal file
3
deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh
Normal file
@@ -0,0 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: add implement
|
||||
3
deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh
Normal file
3
deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh
Normal file
@@ -0,0 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: add implement
|
||||
@@ -10,13 +10,14 @@
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include "mma_utils.cuh"
|
||||
#include "scheduler.cuh"
|
||||
#include "tma_utils.cuh"
|
||||
#include "utils.cuh"
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd>
|
||||
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) {
|
||||
if (num_former_iters == kNumFormerIters) {
|
||||
@@ -28,59 +29,58 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
|
||||
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t BLOCK_N_PADDING,
|
||||
uint32_t kSwizzleDMode,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||
GemmType kGemmType>
|
||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||
fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
__global__ void __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
||||
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d,
|
||||
const __grid_constant__ CUtensorMap tensor_map_sfa) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
|
||||
DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
||||
|
||||
// Overwrite shape constants if the compiler gives
|
||||
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
||||
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
||||
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
||||
|
||||
// Shared memory
|
||||
static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
|
||||
static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);
|
||||
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K);
|
||||
const uint32_t& smem_sfb_size = align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
|
||||
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
|
||||
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
const uint32_t lane_idx = get_lane_idx();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
|
||||
|
||||
// `tensor_map_d` is only used in swizzling mode
|
||||
// For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
|
||||
if constexpr (kSwizzleDMode > 0)
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_sfa));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
@@ -92,8 +92,8 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
||||
__nv_fp8_e4m3* smem_a[kNumStages];
|
||||
__nv_fp8_e4m3* smem_b[kNumStages];
|
||||
float* smem_scales_a[kNumStages];
|
||||
float* smem_scales_b;
|
||||
float* smem_sfa[kNumStages];
|
||||
float* smem_sfb;
|
||||
|
||||
// TMA Barrier for both divisible and non-divisible cases
|
||||
Barrier* full_barriers[kNumStages];
|
||||
@@ -104,12 +104,12 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
smem_sfa[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SFA_SIZE_PER_STAGE);
|
||||
}
|
||||
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
|
||||
smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE));
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i] = barrier_start_ptr + i;
|
||||
@@ -129,7 +129,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
@@ -140,7 +140,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
struct NotDivisibleK {};
|
||||
struct SkipComputation {};
|
||||
struct NotSkipComputation {};
|
||||
auto launch_k_iterations = [](const auto& func, bool skip_computation, uint32_t num_former_iters) {
|
||||
auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) {
|
||||
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
|
||||
constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
|
||||
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
|
||||
@@ -149,15 +149,15 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
// Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
|
||||
outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) {
|
||||
if (skip_computation) {
|
||||
for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type);
|
||||
} else if (SHAPE_K % kFullKOfAllStages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
|
||||
} else if (shape_k % kFullKOfAllStages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
|
||||
} else {
|
||||
for (uint32_t k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
|
||||
func(kNumIterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
|
||||
func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
|
||||
}
|
||||
}, func, kShouldOptimize ? num_former_iters : 0);
|
||||
};
|
||||
@@ -168,7 +168,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, grouped_layout);
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, shape_n, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
@@ -180,7 +180,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
|
||||
// Assign TMA multicast number into A and B
|
||||
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
||||
@@ -194,30 +194,31 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
|
||||
// Issue TMA A
|
||||
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
|
||||
smem_a[s], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
|
||||
num_tma_multicast_a);
|
||||
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_scales_a[s], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K),
|
||||
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_sfa[s], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_idx / BLOCK_K),
|
||||
num_tma_multicast_a);
|
||||
|
||||
// Issue TMA B
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
|
||||
num_tma_multicast_b);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
}, false, 0);
|
||||
@@ -227,7 +228,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -235,33 +236,33 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
||||
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Decide the number of scales B to load
|
||||
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
|
||||
DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0);
|
||||
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
|
||||
if constexpr (not kMustUseUniformedScaleB) {
|
||||
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
|
||||
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
|
||||
num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8;
|
||||
}
|
||||
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
|
||||
uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2);
|
||||
|
||||
// Load B scales with math warp-groups
|
||||
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
|
||||
if (threadIdx.x >= 32) {
|
||||
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
|
||||
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
|
||||
auto num_previous_lines = scheduler.get_global_idx<true>(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx);
|
||||
auto local_sfb = sfb + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales;
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
|
||||
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
|
||||
for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
|
||||
st_shared(smem_sfb + i, __ldg(local_sfb + i));
|
||||
}
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Accumulation for WGMMA or CUDA promotion
|
||||
constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
|
||||
constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2);
|
||||
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
|
||||
|
||||
@@ -279,19 +280,18 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) {
|
||||
constexpr bool kSkipComputation = std::is_same_v<decltype(skip_type), SkipComputation>;
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 :
|
||||
(kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K);
|
||||
constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Read B scales
|
||||
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
|
||||
float scale_b_0 = ld_shared(smem_sfb + k_iter * kNumStages + s), scale_b_1;
|
||||
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
|
||||
scale_b_1 = ld_shared(smem_sfb + k_iter * kNumStages + s + shape_k_scales);
|
||||
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
|
||||
// TODO: remove some useless computation for unaligned Ms
|
||||
#pragma unroll
|
||||
@@ -300,8 +300,8 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
|
||||
// Read A scales
|
||||
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
||||
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
|
||||
auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
|
||||
auto scale_a_0 = ld_shared(smem_sfa[s] + r_0 + m_offset);
|
||||
auto scale_a_1 = ld_shared(smem_sfa[s] + r_1 + m_offset);
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
@@ -347,7 +347,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
}, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters);
|
||||
@@ -360,8 +360,6 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
|
||||
"Unaligned TMA store or too many TMA store instructions");
|
||||
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
|
||||
DG_STATIC_ASSERT(static_cast<uint32_t>(kSwizzleDMode > 0) + static_cast<uint32_t>(BLOCK_N_PADDING > 0) <= 1,
|
||||
"Swizzling and padding are not compatible");
|
||||
|
||||
// Wait last TMA store to be finished
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
|
||||
@@ -403,9 +401,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
} else {
|
||||
// No swizzling, just padding
|
||||
// NOTES: padding must be zero for BF16 output
|
||||
DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output");
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
|
||||
}
|
||||
|
||||
// NOTES: only 16 lanes' addresses are used
|
||||
@@ -421,13 +417,14 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
// TODO: compatible with FP32 output
|
||||
constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked;
|
||||
DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
|
||||
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
|
||||
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
||||
n_block_idx * BLOCK_N + in_block_n_offset,
|
||||
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
|
||||
scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx));
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
__syncwarp();
|
||||
@@ -441,4 +438,4 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
#pragma clang diagnostic pop
|
||||
139
deep_gemm/include/deep_gemm/impls/smxx_layout.cuh
Normal file
139
deep_gemm/include/deep_gemm/impls/smxx_layout.cuh
Normal file
@@ -0,0 +1,139 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// NOTES: the two kernels below always pack the K dimension
|
||||
|
||||
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
|
||||
__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) {
|
||||
extern __shared__ uint32_t smem_buffer[];
|
||||
|
||||
// Shapes and strides
|
||||
constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u);
|
||||
constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(int));
|
||||
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
|
||||
const auto tma_aligned_mn = align<uint64_t>(mn, kNumTMAAlignedElems);
|
||||
|
||||
// Shift into the group
|
||||
sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
|
||||
out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * kNumPackedSFK;
|
||||
|
||||
// Load FP32 SFs
|
||||
DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size");
|
||||
const auto local_sf = reinterpret_cast<uint32_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
|
||||
const auto num_values = in_block_mn * SF_K;
|
||||
const auto num_uint4 = num_values / 4;
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) {
|
||||
const auto& [x, y, z, w] = __ldg(reinterpret_cast<uint4*>(local_sf) + i);
|
||||
st_shared(reinterpret_cast<uint4*>(smem_buffer) + i, x, y, z, w);
|
||||
}
|
||||
|
||||
// Fill unaligned values as well
|
||||
if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values)
|
||||
st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx));
|
||||
__syncthreads();
|
||||
|
||||
// Pack into UE8M0 and store
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) {
|
||||
const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN;
|
||||
|
||||
// Load shared memory
|
||||
uint32_t values[4];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 4; ++ j) {
|
||||
const auto sf_k_idx = sf_k_pack_idx * 4 + j;
|
||||
values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0;
|
||||
}
|
||||
|
||||
// Pack and store
|
||||
uint32_t packed = 0;
|
||||
packed |= (values[0] >> 23u);
|
||||
packed |= (values[1] >> 15u);
|
||||
packed |= (values[2] >> 7u);
|
||||
packed |= (values[3] << 1u);
|
||||
if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn)
|
||||
out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed;
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t kNumGroups, uint32_t kNumThreads,
|
||||
uint32_t BLOCK_MN, uint32_t BLOCK_PACKED_SF_K, bool kTransposed = true>
|
||||
__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
|
||||
const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) {
|
||||
// Always packing the K dimension
|
||||
// NOTES: should also assert `mn % 4 == 0` at launch
|
||||
DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)");
|
||||
DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes");
|
||||
DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes");
|
||||
|
||||
// Shapes and strides
|
||||
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
|
||||
const auto in_block_mn_uint4 = in_block_mn / 4;
|
||||
const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K);
|
||||
|
||||
// Shift into the right block along MN
|
||||
sf += blockIdx.x * BLOCK_MN;
|
||||
out += blockIdx.x * BLOCK_MN;
|
||||
|
||||
// Each warp is responsible for a packed row
|
||||
const auto warp_idx = threadIdx.x / 32;
|
||||
const auto lane_idx = get_lane_idx();
|
||||
const auto packed_sf_k_idx = static_cast<uint64_t>(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx;
|
||||
if (warp_idx >= in_block_packed_sf_k)
|
||||
return;
|
||||
|
||||
// Make an offset on the input
|
||||
uint32_t input_offset = 0;
|
||||
if constexpr (kNumGroups > 1) {
|
||||
// Load each group's size
|
||||
DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups");
|
||||
uint32_t group_ks[4];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; ++ i) {
|
||||
const auto group_idx = lane_idx * 4 + i;
|
||||
group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Make the offset
|
||||
sf_k = 0;
|
||||
auto sum_packed_sf_k = 0;
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumGroups; ++ i) {
|
||||
const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4);
|
||||
sf_k += sf_k_in_group;
|
||||
sum_packed_sf_k += ceil_div(sf_k_in_group, 4u);
|
||||
if (packed_sf_k_idx < sum_packed_sf_k)
|
||||
break;
|
||||
if (const auto remainder = sf_k_in_group % 4; remainder > 0)
|
||||
input_offset += 4 - remainder;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
|
||||
// Load
|
||||
uint4 values[4];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 4; ++ j) {
|
||||
values[j] = make_uint4(0, 0, 0, 0);
|
||||
if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k)
|
||||
values[j] = __ldg(reinterpret_cast<uint4*>(sf + sf_k_idx * mn) + mn_idx);
|
||||
}
|
||||
|
||||
// Pack and store
|
||||
uint4 packed;
|
||||
packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u);
|
||||
packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u);
|
||||
packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u);
|
||||
packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u);
|
||||
reinterpret_cast<uint4*>(out + packed_sf_k_idx * mn)[mn_idx] = packed;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -1,163 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class GemmType {
|
||||
Normal,
|
||||
GroupedContiguous,
|
||||
GroupedMasked
|
||||
};
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
|
||||
template <GemmType kGemmType,
|
||||
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
|
||||
uint32_t kNum1DBlocksPerGroup = 16>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
uint32_t num_aligned_m_blocks;
|
||||
|
||||
// For normal GEMM
|
||||
// Maybe not used in the masked grouped GEMM
|
||||
uint32_t num_blocks;
|
||||
uint32_t num_blocks_in_group;
|
||||
bool is_peer_cta_alive = true;
|
||||
|
||||
// For grouped GEMM
|
||||
int* grouped_layout;
|
||||
|
||||
// Only used for masked layout
|
||||
uint32_t curr_group_idx, curr_cumsum;
|
||||
|
||||
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m,
|
||||
int* grouped_layout = nullptr) {
|
||||
num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
} else if (kGemmType == GemmType::GroupedContiguous) {
|
||||
num_blocks = num_aligned_m_blocks * kNumNBlocks;
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::GroupedMasked) {
|
||||
curr_group_idx = curr_cumsum = 0;
|
||||
this->grouped_layout = grouped_layout;
|
||||
}
|
||||
}
|
||||
|
||||
// ReSharper disable once CppNotAllPathsReturnValue
|
||||
__device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return true;
|
||||
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
|
||||
return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
|
||||
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
|
||||
return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
|
||||
if (num_blocks_in_group == 1)
|
||||
return false;
|
||||
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
|
||||
return true;
|
||||
} else {
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type");
|
||||
if constexpr (kIsTMAMulticastOnA) {
|
||||
return true;
|
||||
} else {
|
||||
auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
|
||||
auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
|
||||
return group_idx == peer_group_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx,
|
||||
uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
|
||||
|
||||
// Swizzle for better L2 usages
|
||||
auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks;
|
||||
auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks;
|
||||
auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
|
||||
auto group_idx = block_idx / num_blocks_per_group;
|
||||
auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
|
||||
auto in_group_idx = block_idx % num_blocks_per_group;
|
||||
num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
|
||||
|
||||
// Fix unaligned TMA multicast
|
||||
if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
|
||||
if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
|
||||
num_blocks_in_group = num_blocks_in_group ^ 1;
|
||||
} else {
|
||||
in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
|
||||
first_block_idx += num_blocks_in_group ^ 1;
|
||||
num_blocks_in_group = 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to final M/N block indices
|
||||
if constexpr (kIsTMAMulticastOnA) {
|
||||
m_block_idx = in_group_idx / num_blocks_in_group;
|
||||
n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
|
||||
} else {
|
||||
m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
|
||||
n_block_idx = in_group_idx / num_blocks_in_group;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool kIgnoreGroupedForGroupedContiguous=true>
|
||||
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size,
|
||||
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
|
||||
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M));
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
|
||||
return curr_group_idx * shape_dim + block_idx * block_size;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
|
||||
|
||||
if constexpr (kGemmType == GemmType::GroupedMasked) {
|
||||
uint32_t num_m_blocks;
|
||||
while (true) {
|
||||
// End of the task
|
||||
if (curr_group_idx == kNumGroups)
|
||||
return false;
|
||||
|
||||
// Within the current group
|
||||
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
|
||||
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
|
||||
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
|
||||
break;
|
||||
|
||||
// Move to check the next group
|
||||
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
|
||||
}
|
||||
|
||||
get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
|
||||
} else {
|
||||
if (next_block_idx >= num_blocks)
|
||||
return false;
|
||||
|
||||
// NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
|
||||
is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or // Always aligned on N (constant bypass)
|
||||
num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass)
|
||||
(next_block_idx ^ 1) < num_blocks; // Peer CTA in bound
|
||||
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -1,19 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// TODO: move this function to other files
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
|
||||
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
|
||||
if (num_tma_multicast == 1) {
|
||||
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
} else if (cute::block_rank_in_cluster() == 0) {
|
||||
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -1,34 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
|
||||
asm volatile("trap;");
|
||||
}
|
||||
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
#ifndef DG_DEVICE_ASSERT
|
||||
#define DG_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
|
||||
asm("trap;"); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_STATIC_ASSERT
|
||||
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ constexpr T ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
|
||||
return b == 0 ? a : constexpr_gcd(b, a % b);
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler
|
||||
from .runtime import Runtime
|
||||
@@ -1,284 +0,0 @@
|
||||
import functools
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple, Type
|
||||
|
||||
import cuda.bindings
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
|
||||
def hash_to_hex(s: str) -> str:
|
||||
md5 = hashlib.md5()
|
||||
md5.update(s.encode('utf-8'))
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_jit_include_dir() -> str:
|
||||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_deep_gemm_version() -> str:
|
||||
md5 = hashlib.md5()
|
||||
|
||||
# Update include directories
|
||||
include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
|
||||
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
||||
with open(os.path.join(include_dir, filename), 'rb') as f:
|
||||
md5.update(f.read())
|
||||
|
||||
# Update `interleave_ffma.py`
|
||||
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f:
|
||||
md5.update(f.read())
|
||||
return md5.hexdigest()[0:12]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
paths = []
|
||||
if os.getenv('DG_JIT_NVCC_COMPILER'):
|
||||
paths.append(os.getenv('DG_JIT_NVCC_COMPILER'))
|
||||
paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
|
||||
|
||||
# Try to find the first available NVCC compiler
|
||||
least_version_required = '12.3'
|
||||
version_pattern = re.compile(r'release (\d+\.\d+)')
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
command = [path, '--version']
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
match = version_pattern.search(result.stdout)
|
||||
version = match.group(1)
|
||||
assert match, f'Cannot get the version of NVCC compiler {path}'
|
||||
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
||||
return path, version
|
||||
raise RuntimeError('Cannot find any available NVCC compiler')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_default_user_dir():
|
||||
if 'DG_JIT_CACHE_DIR' in os.environ:
|
||||
path = os.getenv('DG_JIT_CACHE_DIR')
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
return os.path.join(os.path.expanduser('~'), '.deep_gemm')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_tmp_dir():
|
||||
return os.path.join(get_default_user_dir(), 'tmp')
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_cache_dir():
|
||||
return os.path.join(get_default_user_dir(), 'cache')
|
||||
|
||||
|
||||
def make_tmp_dir():
|
||||
tmp_dir = get_tmp_dir()
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
def put(path, data):
|
||||
# Write and do POSIX atomic replace
|
||||
tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
|
||||
with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f:
|
||||
f.write(data)
|
||||
os.replace(tmp_file_path, path)
|
||||
|
||||
|
||||
class Compiler:
|
||||
@classmethod
|
||||
def signature(cls) -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def flags() -> List[str]:
|
||||
cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20))
|
||||
return [f'-std=c++{cpp_standard}',
|
||||
'--ptxas-options=--register-usage-level=10' +
|
||||
(',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=39,161,174,177,186,940']
|
||||
|
||||
@staticmethod
|
||||
def include_dirs() -> List[str]:
|
||||
return [get_jit_include_dir()]
|
||||
|
||||
@classmethod
|
||||
def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
|
||||
# Compiler flags
|
||||
flags = cls.flags()
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0))
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
|
||||
if cached_runtime is not None:
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return cached_runtime
|
||||
|
||||
# Compile into a temporary CU file
|
||||
os.makedirs(path, exist_ok=True)
|
||||
cubin_path = os.path.join(path, 'kernel.cubin')
|
||||
tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
|
||||
|
||||
start_time = time.time()
|
||||
cls.compile(name, code, tmp_cubin_path)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_cubin_path)
|
||||
|
||||
# Atomic replace files
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime = runtime_cache.get(path, runtime_cls, name, kwargs, force_enable_cache=True)
|
||||
assert runtime is not None
|
||||
return runtime
|
||||
|
||||
|
||||
class NVCCCompiler(Compiler):
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
_, version = get_nvcc_compiler()
|
||||
major, minor = map(int, version.split('.'))
|
||||
return major, minor
|
||||
|
||||
@classmethod
|
||||
def signature(cls) -> str:
|
||||
return f'{get_nvcc_compiler()[0]}+{cls.__version__()}'
|
||||
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
|
||||
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
f'--compiler-options={",".join(cxx_flags)}']
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> None:
|
||||
# Write the code
|
||||
path = os.path.join(get_cache_dir(), name)
|
||||
src_path = os.path.join(path, 'kernel.cu')
|
||||
put(src_path, code)
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', target_path,
|
||||
*cls.flags()]
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}')
|
||||
assert False, f'Failed to compile {src_path}'
|
||||
|
||||
|
||||
class NVRTCCompiler(Compiler):
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
res, major, minor = nvrtc.nvrtcVersion()
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
# Failed to get the actual NVRTC version, use cuda-bindings version instead
|
||||
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
|
||||
return major, minor
|
||||
|
||||
@classmethod
|
||||
def signature(cls) -> str:
|
||||
return f'nvrtc+{cls.__version__()}'
|
||||
|
||||
@staticmethod
|
||||
def include_dirs() -> List[str]:
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
|
||||
return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')]
|
||||
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'--gpu-architecture=sm_90a', '-default-device']
|
||||
# NOTES: PCH is vital for compilation speed
|
||||
if cls.__version__() >= (12, 8):
|
||||
flags += ['--pch']
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
flags += ['--pch-verbose=true']
|
||||
return flags
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, code: str, target_path: str) -> None:
|
||||
# Create program
|
||||
code_bytes = bytes(code, 'utf-8')
|
||||
result, program = nvrtc.nvrtcCreateProgram(
|
||||
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}'
|
||||
|
||||
# Compile
|
||||
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
|
||||
print(f'Compiling JIT runtime {name} with options: {options}')
|
||||
compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0]
|
||||
|
||||
# Print compiler log
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}'
|
||||
|
||||
log_bytes = bytes(log_size)
|
||||
result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}'
|
||||
print(f'Compiler log: {log_bytes.decode("utf-8")}')
|
||||
|
||||
# Exit if failed
|
||||
assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}'
|
||||
|
||||
# Create CUBIN
|
||||
result, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}'
|
||||
cubin_bytes = bytes(cubin_size)
|
||||
result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
|
||||
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}'
|
||||
|
||||
# Write into the file system
|
||||
put(target_path, cubin_bytes)
|
||||
|
||||
# Destroy handler
|
||||
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
|
||||
|
||||
|
||||
def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
|
||||
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
|
||||
return compiler_cls.build(name, code, runtime_cls, kwargs)
|
||||
@@ -1,137 +0,0 @@
|
||||
import argparse
|
||||
import mmap
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
|
||||
def run_cuobjdump(file_path):
|
||||
command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
return result.stdout
|
||||
|
||||
|
||||
def extract_ffma(sass):
|
||||
lines = sass.splitlines()
|
||||
collected = []
|
||||
current = []
|
||||
|
||||
arch_name, func_name = 'N/A', 'N/A'
|
||||
skip_next_line = False
|
||||
for line in lines:
|
||||
if 'code for' in line:
|
||||
arch_name = line.lstrip().lstrip('code for ').rstrip()
|
||||
elif 'Function :' in line:
|
||||
func_name = line.lstrip().lstrip('Function :').rstrip()
|
||||
elif 'FFMA' in line:
|
||||
current.append(line)
|
||||
skip_next_line = True
|
||||
elif skip_next_line:
|
||||
current.append(line)
|
||||
skip_next_line = False
|
||||
else:
|
||||
if len(current) >= 16:
|
||||
assert len(current) % 2 == 0
|
||||
collected.append((f'{arch_name}::{func_name}', current))
|
||||
current = []
|
||||
|
||||
if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)):
|
||||
print(f'Found {len(collected)} FFMA segments')
|
||||
return collected
|
||||
|
||||
|
||||
def extract_hex_from_line(line):
|
||||
match = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line)
|
||||
assert match
|
||||
return int(match.group(1), 16)
|
||||
|
||||
|
||||
def validate(m, offset, le_bytes, num_lines):
|
||||
assert len(le_bytes) == num_lines // 2
|
||||
assert m[offset:offset + 16] == le_bytes[0]
|
||||
for i in range(1, num_lines // 2):
|
||||
if m[offset + i * 16:offset + i * 16 + 16] != le_bytes[i]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def parse_registers(line):
|
||||
line = re.sub(r'/\*.*?\*/', '', line)
|
||||
line = line.replace(';', '')
|
||||
tokens = line.strip().split(',')
|
||||
registers = []
|
||||
for token in tokens:
|
||||
token = token.strip()
|
||||
words = token.split()
|
||||
for word in words:
|
||||
if word.startswith('R'):
|
||||
reg = word.split('.')[0]
|
||||
registers.append(reg)
|
||||
return registers
|
||||
|
||||
|
||||
def modify_segment(m, name, ffma_lines):
|
||||
num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2
|
||||
assert num_lines % 2 == 0
|
||||
|
||||
le_bytes, new_le_bytes = [], []
|
||||
reused_list = []
|
||||
dst_reg_set = set()
|
||||
last_reused, last_dst_reg = False, ''
|
||||
num_changed = 0
|
||||
for i in range(num_lines // 2):
|
||||
dst_reg = parse_registers(ffma_lines[i * 2])[-2]
|
||||
low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1]
|
||||
low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line)
|
||||
le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
|
||||
reused = (high_hex & 0x0800000000000000) != 0
|
||||
if reused:
|
||||
is_first_occurred = dst_reg not in dst_reg_set
|
||||
if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
|
||||
# Modify the `reuse` and `yield` bits
|
||||
assert high_hex & 0x0800200000000000, f'{hex(high_hex)}'
|
||||
high_hex ^= 0x0800200000000000
|
||||
reused = False
|
||||
num_changed += 1
|
||||
else:
|
||||
reused_list.append(i)
|
||||
dst_reg_set.add(dst_reg)
|
||||
new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
|
||||
last_reused, last_dst_reg = reused, dst_reg
|
||||
if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)):
|
||||
print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')
|
||||
|
||||
# Find the offset
|
||||
offsets = []
|
||||
offset = m.find(le_bytes[0])
|
||||
while offset != -1:
|
||||
offsets.append(offset)
|
||||
offset = m.find(le_bytes[0], offset + 1)
|
||||
offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets))
|
||||
|
||||
# Replace with `new_le_bytes`
|
||||
for offset in offsets:
|
||||
for i in range(num_lines // 2):
|
||||
m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i]
|
||||
|
||||
|
||||
def process(path):
|
||||
if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)):
|
||||
print(f'Processing {path}')
|
||||
output = run_cuobjdump(path)
|
||||
segments = extract_ffma(output)
|
||||
with open(path, 'r+b') as f:
|
||||
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE)
|
||||
for segment in segments:
|
||||
modify_segment(mm, *segment)
|
||||
mm.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
|
||||
parser.add_argument('--so', help='Path to the SO file')
|
||||
args = parser.parse_args()
|
||||
|
||||
process(args.so)
|
||||
@@ -1,105 +0,0 @@
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import torch
|
||||
import cuda.bindings.driver as cbd
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@staticmethod
|
||||
def is_path_valid(path: str) -> bool:
|
||||
# Exists and is a directory
|
||||
if not os.path.exists(path) or not os.path.isdir(path):
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ['kernel.cubin']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
@staticmethod
|
||||
def generate(kwargs: Dict[str, Any]) -> str:
|
||||
raise NotImplemented
|
||||
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||
raise NotImplemented
|
||||
|
||||
def __call__(self, **kwargs) -> cbd.CUresult:
|
||||
# Load CUBIN
|
||||
if self.kernel is None:
|
||||
start_time = time.time_ns()
|
||||
|
||||
# Load CUBIN
|
||||
path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8')
|
||||
result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0)
|
||||
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}'
|
||||
|
||||
# Extract the kernel name
|
||||
# TODO: use `cuda-bindings` API to do this (requires at least 12.8)
|
||||
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
illegal_names = ['vprintf', '__instantiate_kernel', '__internal', '__assertfail']
|
||||
check_illegal = lambda line: any([name in line for name in illegal_names])
|
||||
kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
|
||||
if line.startswith('STT_FUNC') and not check_illegal(line)]
|
||||
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
|
||||
|
||||
# Load kernel from the library
|
||||
result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8'))
|
||||
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}'
|
||||
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1e6
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')
|
||||
|
||||
# noinspection PyArgumentList
|
||||
return self.launch(self.kernel, kwargs)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.lib is not None:
|
||||
res = cbd.cuLibraryUnload(self.lib)[0]
|
||||
if res != cbd.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to unload library {self.path}: {res}')
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
def __init__(self) -> None:
|
||||
self.cache = {}
|
||||
|
||||
def __setitem__(self, path: str, runtime: Runtime) -> None:
|
||||
self.cache[path] = runtime
|
||||
|
||||
def get(self, path: str, runtime_cls: Type[Runtime],
|
||||
name: str = '', kwargs: Dict[str, Any] = None,
|
||||
force_enable_cache: bool = False) -> Optional[Runtime]:
|
||||
# In Python runtime
|
||||
if path in self.cache:
|
||||
return self.cache[path]
|
||||
|
||||
# Already compiled
|
||||
use_cache = force_enable_cache or not int(os.getenv('DG_JIT_DISABLE_CACHE', 0))
|
||||
if use_cache and os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
# Print heuristic for the first time
|
||||
if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))):
|
||||
simplified_kwargs = dict()
|
||||
for key, value in kwargs.items() if kwargs is not None else dict().items():
|
||||
value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value
|
||||
value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value
|
||||
simplified_kwargs[key] = value
|
||||
print(f'Put kernel {name} with {simplified_kwargs} into runtime cache')
|
||||
|
||||
runtime = runtime_cls(path)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
||||
@@ -1,14 +0,0 @@
|
||||
from .gemm import gemm_fp8_fp8_bf16_nt
|
||||
from .m_grouped_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||
)
|
||||
from .wgrad_gemm import (
|
||||
wgrad_gemm_fp8_fp8_fp32_nt,
|
||||
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt
|
||||
)
|
||||
from .utils import (
|
||||
ceil_div, set_num_sms, get_num_sms,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_m_alignment_for_contiguous_layout
|
||||
)
|
||||
@@ -1,242 +0,0 @@
|
||||
import math
|
||||
import torch
|
||||
from functools import lru_cache
|
||||
from typing import Tuple
|
||||
|
||||
from ..jit import build
|
||||
from .runtime import (
|
||||
FP8GemmRuntime, GemmType,
|
||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
||||
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||
|
||||
|
||||
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
|
||||
require_divisible: bool = False) -> bool:
|
||||
divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible
|
||||
return divisible and num_sms % num_tma_multicast == 0
|
||||
|
||||
|
||||
def get_swizzle_mode(block_n: int) -> int:
|
||||
elem_size = 2
|
||||
for mode_bytes in (128, 64, 32):
|
||||
if (block_n * elem_size) % mode_bytes == 0:
|
||||
return mode_bytes
|
||||
return 0
|
||||
|
||||
|
||||
def get_block_n_padding_for_smem_d(block_n: int) -> int:
|
||||
# NOTES: padding is for solving bank conflicts, but wastes shared memory space
|
||||
elem_size, requirement = 2, (4, 8)
|
||||
bank_stride = (block_n * elem_size) // 4
|
||||
padding = (requirement[0] - bank_stride) % requirement[1]
|
||||
return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size
|
||||
|
||||
|
||||
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
|
||||
is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
|
||||
assert block_k == 128
|
||||
|
||||
# Try swizzle first, as it does not waste shared memory
|
||||
swizzle_mode = get_swizzle_mode(block_n)
|
||||
block_n_padding = get_block_n_padding_for_smem_d(
|
||||
block_n) if swizzle_mode == 0 else 0
|
||||
|
||||
# NOTES: `scales_b` in a total manner or per-stage manner
|
||||
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = block_m * 4
|
||||
smem_b_per_stage = block_n * block_k
|
||||
smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
|
||||
smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
|
||||
smem_barrier = (num_stages + int(is_wgrad)) * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
smem_size += smem_d
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
smem_size += num_stages * smem_scales_b_per_stage
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
|
||||
smem_size += smem_barrier
|
||||
|
||||
# Swizzle and padding are not compatible
|
||||
assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1
|
||||
|
||||
return smem_size, swizzle_mode, block_n_padding
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False,
|
||||
is_fp32_out: bool = False, is_wgrad: bool = False) -> \
|
||||
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
|
||||
if not is_grouped_contiguous:
|
||||
block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
|
||||
|
||||
# Avoid bank conflicts for FP32 output
|
||||
if is_fp32_out:
|
||||
block_ns = [x for x in block_ns if x % 16 == 8]
|
||||
|
||||
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||
get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
|
||||
|
||||
# Decide block sizes by waves
|
||||
best_block_m, best_block_n = None, None
|
||||
for block_m in block_ms:
|
||||
# NOTES: the block sizes cannot be too large, so at least one dim less than 128
|
||||
for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
|
||||
success = False
|
||||
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
||||
if best_block_m is None or best_block_n is None:
|
||||
success = True
|
||||
elif num_waves < best_num_waves:
|
||||
success = True
|
||||
elif num_waves == best_num_waves:
|
||||
# Check last wave utilization
|
||||
util = get_last_wave_util(block_m, block_n)
|
||||
best_util = get_last_wave_util(best_block_m, best_block_n)
|
||||
success = util > best_util
|
||||
if util == best_util:
|
||||
# Case 1: same `block_m`, smaller `block_n` (wasted)
|
||||
success |= block_m == best_block_m and block_n < best_block_n
|
||||
# Case 2: same `block_n`, smaller `block_m` (wasted)
|
||||
success |= block_n == best_block_n and block_m < best_block_m
|
||||
# Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
|
||||
success |= block_m != best_block_m and block_n > best_block_n
|
||||
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
|
||||
assert best_block_m is not None and best_block_n is not None
|
||||
|
||||
# Always pick the longest one
|
||||
# NOTES: for double B scales, the best number of stages may be reduced
|
||||
best_num_stages, best_smem_config, sm90_capacity = None, None, 232448
|
||||
stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (8, 7, 6, 5, 4, 3, 2, 1)))
|
||||
if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
|
||||
# Unrolling both stages and `num_former_iters` will cause large code size
|
||||
stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1)))
|
||||
for num_stages in stage_candidates:
|
||||
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad)
|
||||
if best_smem_config[0] <= sm90_capacity:
|
||||
best_num_stages = num_stages
|
||||
break
|
||||
assert best_smem_config is not None
|
||||
assert best_num_stages is not None
|
||||
|
||||
# Decide the number of TMA multicasts and whether broadcast on A
|
||||
best_tma_multicast_config = (1, True)
|
||||
|
||||
# Try to multicast on the larger block side first
|
||||
# NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even
|
||||
is_multicast_legal = {
|
||||
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
|
||||
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
|
||||
}
|
||||
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
|
||||
if m >= 512 and is_multicast_legal[i]:
|
||||
best_tma_multicast_config = (2, i == 'A')
|
||||
break
|
||||
|
||||
# Recompute the minimal number of SMs required
|
||||
# NOTES: less L2 cache usage and less GPU frequency drop
|
||||
num_waves = get_num_waves(best_block_m, best_block_n)
|
||||
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
|
||||
num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
|
||||
assert num_min_sms <= num_sms
|
||||
|
||||
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
|
||||
|
||||
|
||||
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor) -> None:
|
||||
"""
|
||||
Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
|
||||
Requirements:
|
||||
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
|
||||
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m, n]`, representing the result.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert n > 0 and k > 0
|
||||
assert lhs_scales.shape == (m, ceil_div(k, 128))
|
||||
assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128))
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||
|
||||
# LHS scales must be transposed for TMA loads, but not for RHS scales
|
||||
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
# K must be aligned to 128
|
||||
aligned_k = ceil_div(k, 128) * 128
|
||||
|
||||
# Auto-tuning with compilation
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
|
||||
tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
|
||||
tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
|
||||
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)
|
||||
|
||||
kwargs = {
|
||||
# Templated arguments
|
||||
'GEMM_TYPE': GemmType.Normal,
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m, 'N': n, 'K': aligned_k,
|
||||
'NUM_GROUPS': 1,
|
||||
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
'BLOCK_N_PADDING': smem_config[2],
|
||||
'NUM_STAGES': num_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||
# Runtime arguments
|
||||
'SCALES_B': rhs_scales,
|
||||
'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
'DEVICE_INDEX': out.device.index
|
||||
}
|
||||
|
||||
# Generate, build and run the kernel
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
|
||||
runtime(**kwargs)
|
||||
@@ -1,205 +0,0 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from ..jit import build
|
||||
from .gemm import get_best_configs
|
||||
from .runtime import (
|
||||
FP8GemmRuntime, GemmType,
|
||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
||||
from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
||||
"""
|
||||
Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
|
||||
Requirements:
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`,
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
|
||||
m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
|
||||
`m_indices[i]` records the group which the i-th row of the LHS belongs to,
|
||||
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
num_groups, n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
m__ = m_indices.numel()
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ == m__ and k == k_ and n == n_
|
||||
assert lhs_scales.shape == (m, ceil_div(k, 128))
|
||||
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert m_indices.dtype == torch.int32
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
assert out.is_contiguous() and m_indices.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
return
|
||||
|
||||
# Auto-tuning with compilation
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
m, n, k, 1, num_sms, is_grouped_contiguous=True)
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedContiguous, lhs, m, k, k, block_m, block_k, num_groups)
|
||||
tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedContiguous, rhs, n, k, k, block_n, block_k, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedContiguous, out, m, n, n, block_m, block_n, num_groups, smem_config[1])
|
||||
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
|
||||
|
||||
kwargs = {
|
||||
# Templated arguments
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m, 'N': n, 'K': k,
|
||||
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
'BLOCK_N_PADDING': smem_config[2],
|
||||
'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||
'GEMM_TYPE': GemmType.GroupedContiguous,
|
||||
# Runtime arguments
|
||||
'SCALES_B': rhs_scales,
|
||||
'GROUPED_LAYOUT': m_indices,
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
'DEVICE_INDEX': out.device.index
|
||||
}
|
||||
|
||||
# Generate, build and run the kernel
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
|
||||
runtime(**kwargs)
|
||||
|
||||
|
||||
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
||||
"""
|
||||
Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
|
||||
Requirements:
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement,
|
||||
this function will do a transposing with a set of slow PyTorch operations.
|
||||
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
||||
should be separately transposed.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
The second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
|
||||
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
||||
in the i-th group.
|
||||
expected_m: a value hint (which is a value on CPU) for the M expectation of each batch,
|
||||
correctly setting this value may lead to better performance.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
num_groups, m, k = lhs.shape
|
||||
num_groups_, n, k_ = rhs.shape
|
||||
num_groups__, m_, n_ = out.shape
|
||||
num_groups___ = masked_m.numel()
|
||||
|
||||
# Type and shape checks
|
||||
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||
assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128))
|
||||
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert masked_m.dtype == torch.int32
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Auto-tuning with compilation
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
|
||||
|
||||
# Extra checks for TMA store
|
||||
if num_groups > 1 and m > block_m:
|
||||
assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
|
||||
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedMasked, lhs, m, k, k, block_m, block_k, num_groups)
|
||||
tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedMasked, rhs, n, k, k, block_n, block_k, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedMasked, out, m, n, n, block_m, block_n, num_groups, smem_config[1])
|
||||
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
|
||||
|
||||
kwargs = {
|
||||
# Templated arguments
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m, 'N': n, 'K': k,
|
||||
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
|
||||
'SWIZZLE_D_MODE': smem_config[1],
|
||||
'BLOCK_N_PADDING': smem_config[2],
|
||||
'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||
'GEMM_TYPE': GemmType.GroupedMasked,
|
||||
# Runtime arguments
|
||||
'SCALES_B': rhs_scales,
|
||||
'GROUPED_LAYOUT': masked_m,
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
'DEVICE_INDEX': out.device.index
|
||||
}
|
||||
|
||||
# Generate, build and run the kernel
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
|
||||
runtime(**kwargs)
|
||||
@@ -1,318 +0,0 @@
|
||||
import ctypes
|
||||
import os
|
||||
import enum
|
||||
import torch
|
||||
import cuda.bindings.driver as cbd
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from .utils import get_tma_aligned_size
|
||||
from ..jit.runtime import Runtime
|
||||
|
||||
|
||||
class GemmType(enum.Enum):
|
||||
Normal = 0
|
||||
GroupedContiguous = 1
|
||||
GroupedMasked = 2
|
||||
|
||||
def __str__(self) -> str:
|
||||
return {
|
||||
0: 'Normal',
|
||||
1: 'GroupedContiguous',
|
||||
2: 'GroupedMasked',
|
||||
}[self.value]
|
||||
|
||||
|
||||
tmap_type_map: Dict[Any, str] = {
|
||||
torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
|
||||
torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
|
||||
torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
|
||||
torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
|
||||
torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
|
||||
torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
|
||||
torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
|
||||
torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
|
||||
torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
}
|
||||
|
||||
swizzle_type_map = {
|
||||
0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
|
||||
32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
|
||||
64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
|
||||
128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
|
||||
}
|
||||
|
||||
|
||||
def get_num_math_warpgroups(block_m: int) -> int:
|
||||
return 1 if block_m == 64 else 2
|
||||
|
||||
|
||||
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
|
||||
assert num_math_threads_per_group == 128, 'Only support 128 threads per math group'
|
||||
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
|
||||
|
||||
|
||||
def make_2d_tma_copy_desc(t: torch.Tensor,
|
||||
gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t,
|
||||
smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
|
||||
swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap:
|
||||
tensor_dtype = tmap_type_map[t.dtype]
|
||||
res, tensor_map = cbd.cuTensorMapEncodeTiled(
|
||||
tensor_dtype,
|
||||
2,
|
||||
t.data_ptr(),
|
||||
gmem_dims,
|
||||
(gmem_outer_stride,),
|
||||
smem_dims,
|
||||
(cbd.cuuint32_t(1), cbd.cuuint32_t(1)),
|
||||
cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
swizzle_type,
|
||||
cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
||||
cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
|
||||
)
|
||||
|
||||
if res != cbd.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to encode tensor map: {res}')
|
||||
return tensor_map
|
||||
|
||||
|
||||
def make_2d_tma_desc(t: torch.Tensor,
|
||||
gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int,
|
||||
smem_inner_dim: int, smem_outer_dim: int,
|
||||
swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
|
||||
gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim))
|
||||
smem_dim = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim))
|
||||
return make_2d_tma_copy_desc(t, gmem_dim, cbd.cuuint64_t(gmem_outer_stride * t.element_size()), smem_dim, swizzle_type)
|
||||
|
||||
|
||||
def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||
shape_m: int, shape_k: int, m_stride: int,
|
||||
block_m: int, block_k: int,
|
||||
num_groups: int) -> cbd.CUtensorMap:
|
||||
return make_2d_tma_desc(t,
|
||||
shape_k, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
|
||||
block_k, block_m)
|
||||
|
||||
|
||||
def make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||
shape_n: int, shape_k: int, n_stride: int,
|
||||
block_n: int, block_k: int,
|
||||
num_groups: int) -> cbd.CUtensorMap:
|
||||
return make_2d_tma_desc(t,
|
||||
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), n_stride,
|
||||
block_k, block_n)
|
||||
|
||||
|
||||
def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||
shape_m: int, shape_n: int, m_stride: int,
|
||||
block_m: int, block_n: int,
|
||||
num_groups: int,
|
||||
swizzle_mode: int) -> cbd.CUtensorMap:
|
||||
# Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
|
||||
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
return make_2d_tma_desc(t,
|
||||
shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
|
||||
block_n if swizzle_mode == 0 else swizzle_mode // t.element_size(), block_m,
|
||||
swizzle_type_map[swizzle_mode])
|
||||
|
||||
|
||||
def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||
shape_mn: int, shape_k: int,
|
||||
block_mn: int, block_k: int,
|
||||
num_groups: int) -> cbd.CUtensorMap:
|
||||
# Make TMA aligned to 16 bytes
|
||||
shape_mn = get_tma_aligned_size(shape_mn, t.element_size())
|
||||
return make_2d_tma_desc(t,
|
||||
shape_mn, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_mn,
|
||||
block_mn, 1,
|
||||
cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
|
||||
|
||||
class FP8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path)
|
||||
|
||||
@staticmethod
|
||||
def generate(kwargs: Dict[str, Any]) -> str:
|
||||
code = f'''
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include <deep_gemm/fp8_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
{kwargs['K']},
|
||||
{kwargs['BLOCK_M']},
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['BLOCK_N_PADDING']},
|
||||
{kwargs['SWIZZLE_D_MODE']},
|
||||
{kwargs['NUM_GROUPS']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>);
|
||||
}};
|
||||
'''
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated FP8 GEMM code:\n{code}')
|
||||
return code
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
|
||||
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
|
||||
|
||||
attr_val = cbd.CUlaunchAttributeValue()
|
||||
attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
|
||||
attr_val.clusterDim.y = 1
|
||||
attr_val.clusterDim.z = 1
|
||||
attr = cbd.CUlaunchAttribute()
|
||||
attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
|
||||
attr.value = attr_val
|
||||
|
||||
config = cbd.CUlaunchConfig()
|
||||
config.numAttrs = 1
|
||||
config.attrs = [attr]
|
||||
config.gridDimX = kwargs['NUM_SMS']
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.sharedMemBytes = kwargs['SMEM_SIZE']
|
||||
config.hStream = kwargs['STREAM']
|
||||
|
||||
arg_values = (
|
||||
kwargs['SCALES_B'].data_ptr(),
|
||||
kwargs['GROUPED_LAYOUT'].data_ptr(),
|
||||
kwargs['M'],
|
||||
kwargs['TENSOR_MAP_A'],
|
||||
kwargs['TENSOR_MAP_B'],
|
||||
kwargs['TENSOR_MAP_SCALES_A'],
|
||||
kwargs['TENSOR_MAP_D'],
|
||||
)
|
||||
arg_types = (
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_uint32,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
|
||||
|
||||
|
||||
class FP8WGradGemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path)
|
||||
|
||||
@staticmethod
|
||||
def generate(kwargs: Dict[str, Any]) -> str:
|
||||
code = f'''
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include <deep_gemm/fp8_wgrad_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&fp8_wgrad_gemm_kernel<
|
||||
{kwargs['M']},
|
||||
{kwargs['N']},
|
||||
{kwargs['BLOCK_M']},
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['NUM_LAST_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}
|
||||
>);
|
||||
}};
|
||||
'''
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated FP8 WGrad GEMM code:\n{code}')
|
||||
return code
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
|
||||
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
|
||||
|
||||
attr_val = cbd.CUlaunchAttributeValue()
|
||||
attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
|
||||
attr_val.clusterDim.y = 1
|
||||
attr_val.clusterDim.z = 1
|
||||
attr = cbd.CUlaunchAttribute()
|
||||
attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
|
||||
attr.value = attr_val
|
||||
|
||||
config = cbd.CUlaunchConfig()
|
||||
config.numAttrs = 1
|
||||
config.attrs = [attr]
|
||||
config.gridDimX = kwargs['NUM_SMS']
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.sharedMemBytes = kwargs['SMEM_SIZE']
|
||||
config.hStream = kwargs['STREAM']
|
||||
|
||||
arg_values = (
|
||||
kwargs['K'],
|
||||
kwargs['TENSOR_MAP_A'],
|
||||
kwargs['TENSOR_MAP_B'],
|
||||
kwargs['TENSOR_MAP_SCALES_A'],
|
||||
kwargs['TENSOR_MAP_SCALES_B'],
|
||||
kwargs['TENSOR_MAP_D'],
|
||||
)
|
||||
arg_types = (
|
||||
ctypes.c_uint32,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
|
||||
@@ -1,109 +0,0 @@
|
||||
import torch
|
||||
|
||||
_num_sms = None
|
||||
|
||||
|
||||
def set_num_sms(num_sms: int) -> None:
|
||||
"""
|
||||
Set the maximum SM count for all GEMM kernels to use.
|
||||
|
||||
Arguments:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
def get_num_sms() -> int:
|
||||
"""
|
||||
Get the current maximum limit of SM count for all GEMM kernels to use.
|
||||
If the count is never specified, the function will return the number of device SMs.
|
||||
|
||||
Returns:
|
||||
Current maximum limit of SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
if _num_sms is None:
|
||||
_num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
"""
|
||||
Perform ceiling division of two integers.
|
||||
|
||||
Args:
|
||||
x: the dividend.
|
||||
y: the divisor.
|
||||
|
||||
Returns:
|
||||
The result of the ceiling division.
|
||||
"""
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def get_m_alignment_for_contiguous_layout():
|
||||
"""
|
||||
When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis.
|
||||
Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well
|
||||
with GEMM block shape.
|
||||
|
||||
Returns:
|
||||
Group-level alignment requirement for grouped contiguous layout, which is always 128.
|
||||
"""
|
||||
return 128
|
||||
|
||||
|
||||
def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
"""
|
||||
Global memory address of TMA must be 16-byte aligned.
|
||||
Since we use column-major layout for the LHS scaling tensor,
|
||||
the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
|
||||
|
||||
Arguments:
|
||||
x: original M-axis shape of the LHS scaling tensor.
|
||||
element_size: element size of the LHS scaling tensor.
|
||||
|
||||
Returns:
|
||||
M-axis shape of the LHS scaling tensor after padding.
|
||||
"""
|
||||
tma_alignment_bytes = 16
|
||||
assert tma_alignment_bytes % element_size == 0
|
||||
alignment = tma_alignment_bytes // element_size
|
||||
return ceil_div(x, alignment) * alignment
|
||||
|
||||
|
||||
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
|
||||
If the input tensor is already column-major layout and 16-byte aligned along the M axis
|
||||
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
|
||||
|
||||
Arguments:
|
||||
x: usually the LHS scaling tensor in GEMM.
|
||||
|
||||
Returns:
|
||||
The LHS scaling tensor of TMA-aligned transposed format.
|
||||
"""
|
||||
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
|
||||
assert x.dim() in (2, 3)
|
||||
remove_dim = False
|
||||
m, n = x.shape[-2], x.shape[-1]
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
if x.dim() == 2:
|
||||
if x.stride(0) == 1 and x.stride(1) == aligned_m:
|
||||
return x
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
b = x.shape[0]
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
# Normal layout requires transposing
|
||||
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
aligned_x[:, :m, :] = x
|
||||
aligned_x = aligned_x[:, :m, :]
|
||||
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||
@@ -1,158 +0,0 @@
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
|
||||
from ..jit import build
|
||||
from .runtime import (
|
||||
FP8WGradGemmRuntime, GemmType,
|
||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||
make_2d_tma_d_desc, make_2d_tma_scales_desc)
|
||||
from .gemm import get_best_configs
|
||||
from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
|
||||
|
||||
|
||||
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor):
|
||||
"""
|
||||
Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
|
||||
Results will be accumulated into the output tensor.
|
||||
|
||||
Requirements:
|
||||
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
|
||||
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling and RHS scaling tensor require a TMA-aligned transposed format.
|
||||
If your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations.
|
||||
|
||||
Arguments:
|
||||
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
||||
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`.
|
||||
out: the FP32 output tensor of shape `[m, n]`, which will be accumulated.
|
||||
"""
|
||||
lhs, lhs_scales = lhs
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert n > 0 and m > 0
|
||||
assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m)
|
||||
assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n)
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert out.dtype == torch.float
|
||||
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
|
||||
|
||||
# LHS and RHS scales must be transposed for TMA load
|
||||
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
|
||||
def get_valid_scales(scales: torch.Tensor, mn: int):
|
||||
if scales.shape == (ceil_div(k, 128), mn):
|
||||
# For k-grouped GEMMs
|
||||
scales = scales.permute(1, 0)
|
||||
assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn
|
||||
else:
|
||||
scales = get_col_major_tma_aligned_tensor(scales)
|
||||
return scales
|
||||
|
||||
lhs_scales = get_valid_scales(lhs_scales, m)
|
||||
rhs_scales = get_valid_scales(rhs_scales, n)
|
||||
|
||||
# Do nothing if `k` is zero
|
||||
if k == 0:
|
||||
return
|
||||
|
||||
# K must be aligned to 128
|
||||
aligned_k = ceil_div(k, 128) * 128
|
||||
|
||||
# Auto-tuning with compilation
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
|
||||
num_last_stages = ceil_div(k, 128) % num_stages
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
|
||||
tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
|
||||
tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
|
||||
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)
|
||||
tensor_map_scales_b = make_2d_tma_scales_desc(GemmType.Normal, rhs_scales, n, k, block_n, block_k, 1)
|
||||
|
||||
kwargs = {
|
||||
# Templated arguments
|
||||
'GEMM_TYPE': GemmType.Normal,
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m, 'N': n, 'K': aligned_k,
|
||||
'NUM_GROUPS': 1,
|
||||
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
|
||||
'NUM_STAGES': num_stages,
|
||||
'NUM_LAST_STAGES': num_last_stages,
|
||||
'NUM_TMA_MULTICAST': tma_multicast_config[0],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
|
||||
# Runtime arguments
|
||||
'NUM_SMS': num_sms,
|
||||
'SMEM_SIZE': smem_config[0],
|
||||
'TENSOR_MAP_A': tensor_map_a,
|
||||
'TENSOR_MAP_B': tensor_map_b,
|
||||
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
|
||||
'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
|
||||
'TENSOR_MAP_D': tensor_map_d,
|
||||
'STREAM': torch.cuda.current_stream().cuda_stream,
|
||||
'DEVICE_INDEX': out.device.index
|
||||
}
|
||||
|
||||
# Generate, build and run the kernel
|
||||
code = FP8WGradGemmRuntime.generate(kwargs)
|
||||
runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs)
|
||||
runtime(**kwargs)
|
||||
|
||||
|
||||
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor,
|
||||
batch_sizes: List[int]):
|
||||
"""
|
||||
Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
|
||||
Results will be accumulated into the output tensor.
|
||||
|
||||
Requirements:
|
||||
This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
|
||||
Each batch's LHS, RHS, and output tensors must be contiguous.
|
||||
The RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format.
|
||||
|
||||
Arguments:
|
||||
lhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
|
||||
and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
|
||||
The second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`,
|
||||
representing the per-128-channel scaling factors.
|
||||
rhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
|
||||
and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
|
||||
The second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`,
|
||||
representing the per-128-channel scaling factors.
|
||||
out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated.
|
||||
batch_sizes: A list of integers specifying the k-dimension for each batch.
|
||||
"""
|
||||
lhs, lhs_scales = lhs[0].view(-1), lhs[1]
|
||||
rhs, rhs_scales = rhs[0].view(-1), rhs[1]
|
||||
num_batches, m, n = out.shape
|
||||
|
||||
lhs_offset, rhs_offset, scales_offset = 0, 0, 0
|
||||
|
||||
for i in range(num_batches):
|
||||
k = batch_sizes[i]
|
||||
lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
|
||||
rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
|
||||
lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
|
||||
rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)]
|
||||
wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i])
|
||||
|
||||
lhs_offset += m * k
|
||||
rhs_offset += n * k
|
||||
scales_offset += ceil_div(k, 128)
|
||||
3
deep_gemm/testing/__init__.py
Normal file
3
deep_gemm/testing/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from . import bench, numeric
|
||||
from .bench import *
|
||||
from .numeric import *
|
||||
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
@@ -31,7 +29,7 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return start_event.elapsed_time(end_event) / num_tests
|
||||
return start_event.elapsed_time(end_event) / num_tests / 1e3
|
||||
|
||||
|
||||
class empty_suppress:
|
||||
@@ -77,8 +75,9 @@ class suppress_stdout_stderr:
|
||||
self.errnull_file.close()
|
||||
|
||||
|
||||
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
|
||||
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True,
|
||||
def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
||||
suppress_kineto_output: bool = False,
|
||||
trace_path: str = None, flush_l2: bool = True,
|
||||
with_multiple_kernels: bool = False):
|
||||
# Conflict with Nsight Systems
|
||||
using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0))
|
||||
@@ -96,12 +95,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
|
||||
with profiler:
|
||||
for i in range(2):
|
||||
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
|
||||
if barrier_comm_profiling:
|
||||
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
lhs @ rhs
|
||||
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
|
||||
for _ in range(num_tests):
|
||||
if flush_l2:
|
||||
torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
|
||||
@@ -116,7 +109,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
|
||||
# Parse the profiling table
|
||||
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
||||
is_tupled = isinstance(kernel_names, tuple)
|
||||
is_tuple = isinstance(kernel_names, tuple)
|
||||
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
|
||||
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
@@ -145,21 +138,4 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
break
|
||||
kernel_times.append(total_time / total_num)
|
||||
|
||||
return tuple(kernel_times) if is_tupled else kernel_times[0]
|
||||
|
||||
|
||||
def calc_diff(x, y):
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def count_bytes(tensors):
|
||||
total = 0
|
||||
for t in tensors:
|
||||
if isinstance(t, tuple):
|
||||
total += count_bytes(t)
|
||||
else:
|
||||
total += t.numel() * t.element_size()
|
||||
return total
|
||||
return tuple(kernel_times) if is_tuple else kernel_times[0]
|
||||
19
deep_gemm/testing/numeric.py
Normal file
19
deep_gemm/testing/numeric.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import torch
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def count_bytes(*tensors):
|
||||
total = 0
|
||||
for t in tensors:
|
||||
if isinstance(t, (tuple, list)):
|
||||
total += count_bytes(*t)
|
||||
elif t is not None:
|
||||
total += t.numel() * t.element_size()
|
||||
return total
|
||||
3
deep_gemm/utils/__init__.py
Normal file
3
deep_gemm/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from . import math, layout
|
||||
from .layout import *
|
||||
from .math import *
|
||||
11
deep_gemm/utils/layout.py
Normal file
11
deep_gemm/utils/layout.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from deep_gemm_cpp import (
|
||||
get_tma_aligned_size,
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
get_mn_major_tma_aligned_tensor,
|
||||
get_mn_major_tma_aligned_packed_ue8m0_tensor,
|
||||
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
|
||||
)
|
||||
|
||||
# Some alias
|
||||
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
||||
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
||||
48
deep_gemm/utils/math.py
Normal file
48
deep_gemm/utils/math.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def align(x: int, y: int) -> int:
|
||||
return ceil_div(x, y) * y
|
||||
|
||||
|
||||
def ceil_to_ue8m0(x: torch.Tensor):
|
||||
assert x.view(-1).amax().item() > 0
|
||||
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf
|
||||
|
||||
|
||||
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2 and x.size(0) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(-1, 128, n)
|
||||
x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2))
|
||||
Reference in New Issue
Block a user