diff --git a/csrc/jit_kernels/impls/smxx_layout.hpp b/csrc/jit_kernels/impls/smxx_layout.hpp index 7aa0638..eda8c1b 100644 --- a/csrc/jit_kernels/impls/smxx_layout.hpp +++ b/csrc/jit_kernels/impls/smxx_layout.hpp @@ -121,6 +121,7 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const to const auto& options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8); auto padded = torch::zeros({num_groups, aligned_mn, aligned_k}, options); + // ReSharper disable once CppExpressionWithoutSideEffects padded.slice(1, 0, mn).slice(2, 0, k).copy_(ue8m0_tensor); padded = padded.view(-1).view(torch::kInt32).view({num_groups, aligned_mn, aligned_k / 4}); @@ -140,8 +141,10 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt)); // Launch the kernel if (batched_sf.is_contiguous()) { + // Fallback to slow PyTorch impl for non-supported cases if ((mn * sf_k) % 4 != 0 and num_groups > 1) return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); + constexpr int block_mn = 48; constexpr int num_threads = 512; const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = { @@ -157,8 +160,10 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code); TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args); } else { + // Fallback to slow PyTorch impl for non-supported cases if (mn % 4 != 0 or num_groups > 1) return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); + DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1); DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);