[Kernel] Initial Machete W4A8 support + Refactors (#9855)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Lucas Wilkinson
2024-11-18 14:59:29 -05:00
committed by GitHub
parent c2170a5b39
commit 96d999fbe8
28 changed files with 2616 additions and 1694 deletions


@@ -6,31 +6,49 @@
 namespace machete {
 
-template <typename TileShapeNKL, typename ElementB, typename BInTensor,
-          typename BTiledOutTensor>
-static __global__ void prepack_B_kernel(BInTensor B_in,
-                                        BTiledOutTensor B_tiled_out) {
-  auto tB_in = local_tile(B_in, TileShapeNKL{},
-                          make_coord(blockIdx.x, blockIdx.y, blockIdx.z));
-  auto tB_out = B_tiled_out(make_coord(_, _),
-                            make_coord(blockIdx.x, blockIdx.y), blockIdx.z);
-
-  auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{},
-                                    Layout<Shape<_4, _32>, Stride<_32, _1>>{},
-                                    Layout<Shape<_1, _2>>{});
-
-  auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
-
-  Tensor thr_tile_S = thr_copy.partition_S(tB_in);
-  Tensor thr_tile_D = thr_copy.partition_D(tB_out);
+template <int threads, typename PrepackedLayoutB, typename BInTensor,
+          typename ElementB>
+static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
+  auto constexpr block_size =
+      Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
+  auto constexpr eles_per_thread = Int<block_size / threads>{};
+  static_assert(block_size % threads == 0,
+                "block_size must be divisible by the number of threads");
+
+  // Which pre-packed block are we responsible for
+  auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
+  auto tB_in = local_tile(
+      B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
+      blk_coord);
+
+  // Find the start offset in the output for this pre-packed block
+  auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
+
+  // Tensor representing a 1:1 mapping to the output space in 1D
+  auto tB_out_linear =
+      make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
+                  make_layout(make_shape(block_size)));
+  // Mapping from output space (1D) to input space
+  auto tB_in_linear = make_tensor(
+      tB_in.data(),
+      tB_in.layout()
+          .compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
+          .with_shape(make_shape(block_size)));
+
+  // Tile for this specific thread (could have used a TiledCopy, but those work
+  // best with 2D layouts; this is a simple 1D layout, so local_tile is enough,
+  // and we are not that concerned with performance for this kernel)
+  auto thr_tB_in_linear =
+      local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
+  auto thr_tB_out_linear =
+      local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
 
   // Construct a register-backed Tensor with the same shape as each thread's
   // partition
-  auto fragment = make_tensor<ElementB>(shape(thr_tile_D));
+  auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
 
   // Copy from GMEM to RMEM and from RMEM to GMEM
-  copy(tiled_copy, thr_tile_S, fragment);
-  copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tile_D);
+  copy(thr_tB_in_linear, fragment);
+  copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
 }
 
 template <typename PrepackedLayoutB, typename InLayout>
@@ -44,18 +62,15 @@ static void prepack_B_template(
   TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
   TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
-  TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0);
 
   auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
   auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
-  auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{});
+  auto L_tiles = size<2>(B_layout);
 
   auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
-  auto B_tiled_out =
-      make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);
 
-  prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB>
-      <<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out);
+  prepack_B_kernel<128, PrepackedLayoutB>
+      <<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
 }
 
 };  // namespace machete
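
For context, the copy pattern the reworked kernel relies on is a simple 1D split of each pre-packed block across the thread block, staged through registers; the interesting part is only the index remapping supplied by PrepackedLayoutB. Below is a minimal sketch of that pattern in plain CUDA, without CuTe and without the interleaved-layout remapping (ppblock_ilvd_NK_to_offset); the kernel name simple_block_copy and its parameters are illustrative only and not part of this change.

#include <cstddef>
#include <cstdint>

// Minimal sketch (plain CUDA, no CuTe) of the copy pattern used by
// prepack_B_kernel: each pre-packed block of BLOCK_SIZE elements is split
// evenly across THREADS threads, and each thread stages its slice in a
// register-backed fragment before writing it back out. The real kernel
// additionally remaps indices through PrepackedLayoutB so the output block
// lands in the interleaved, pre-packed order.
template <int THREADS, int BLOCK_SIZE>
__global__ void simple_block_copy(const uint8_t* in, uint8_t* out) {
  static_assert(BLOCK_SIZE % THREADS == 0,
                "block size must be divisible by the number of threads");
  constexpr int eles_per_thread = BLOCK_SIZE / THREADS;

  // Start offset of the pre-packed block this thread block is responsible for
  // (the real kernel gets this from PrepackedLayoutB::bNbKL_to_offset).
  const size_t blk_offset = static_cast<size_t>(blockIdx.x) * BLOCK_SIZE;
  const size_t thr_offset = blk_offset + threadIdx.x * eles_per_thread;

  // Register-backed fragment, analogous to `fragment` in prepack_B_kernel.
  uint8_t frag[eles_per_thread];

  // GMEM -> RMEM
  for (int i = 0; i < eles_per_thread; ++i) frag[i] = in[thr_offset + i];

  // RMEM -> GMEM (the real kernel writes through the remapped 1D output view)
  for (int i = 0; i < eles_per_thread; ++i) out[thr_offset + i] = frag[i];
}

A launch such as simple_block_copy<128, 128><<<num_blocks, 128, 0, stream>>>(in, out) mirrors the prepack_B_kernel<128, PrepackedLayoutB><<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>> launch above, with one thread block per pre-packed block.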