Fix smxx layout assertion (#141)

* Fix assertion error in smxx_layout.hpp for mn % 4 != 0 cases

* Fix assertion error in smxx_layout.hpp for mn % 4 != 0 cases

* Align submodule files

* Fix assertion error in smxx_layout.hpp for mn % 4 != 0 cases

* fix(smxx_layout): support mn%4!=0 and num_groups>1 via torch

* fix(smxx_layout): support mn%4!=0 and num_groups>1 via torch

* fix: correct logic for entering get_mn_major_tma_aligned_packed_ue8m0_tensor_torch
This commit is contained in:
LJC00118
2025-07-30 10:36:54 +08:00
committed by GitHub
parent a581263f06
commit 6bc75b549e

View File

@@ -108,16 +108,40 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf;
}
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) {
const auto& sf_reshaped = (sf.dim() == 2) ? sf.unsqueeze(0) : sf;
// First, convert into UE8M0 `uint8_t`
const auto& ue8m0_tensor = sf_reshaped.view(torch::kInt32).bitwise_right_shift(23).to(torch::kUInt8);
// Second, make padded packed tensors
const auto& [num_groups, mn, k] = get_shape<3>(sf_reshaped);
const auto& aligned_mn = get_tma_aligned_size(mn, 4);
const auto& aligned_k = align(k, 4);
const auto& options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8);
auto padded = torch::zeros({num_groups, aligned_mn, aligned_k}, options);
padded.slice(1, 0, mn).slice(2, 0, k).copy_(ue8m0_tensor);
padded = padded.view(-1).view(torch::kInt32).view({num_groups, aligned_mn, aligned_k / 4});
// Finally, transpose
auto out = torch::empty_strided({num_groups, aligned_mn, aligned_k / 4},
{aligned_mn * (aligned_k / 4), 1, aligned_mn},
at::TensorOptions().device(sf.device()).dtype(torch::kInt32));
out = out.copy_(padded).slice(1, 0, mn);
return (sf.dim() == 2) ? out.squeeze(0) : out;
}
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) {
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
const auto& packed_sf_k = ceil_div(sf_k, 4);
const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k},
{packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn},
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
DG_HOST_ASSERT(num_groups == 1 or (mn * sf_k) % 4 == 0);
// Launch the kernel
if (batched_sf.is_contiguous()) {
if ((mn * sf_k) % 4 != 0 and num_groups > 1)
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
constexpr int block_mn = 48;
constexpr int num_threads = 512;
const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = {
@@ -133,6 +157,8 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
} else {
if (mn % 4 != 0 or num_groups > 1)
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1);
DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);