use ceil_div in cutlass block scaling shape check (#17918)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#include <torch/all.h>
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
||||
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
@@ -28,29 +29,46 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
using GroupShape = std::array<int64_t, 2>;
|
||||
auto make_group_shape = [](torch::Tensor const& x,
|
||||
torch::Tensor const& s) -> GroupShape {
|
||||
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
||||
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
||||
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
||||
};
|
||||
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
||||
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 100) {
|
||||
TORCH_CHECK(
|
||||
a.size(0) == a_scales.size(0) &&
|
||||
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
|
||||
"a_scale_group_shape must be [1, 128].");
|
||||
TORCH_CHECK(
|
||||
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
|
||||
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
|
||||
"b_scale_group_shape must be [128, 128].");
|
||||
} else {
|
||||
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
|
||||
// kernel, or introducing ceil_div to the load_init() of mainloop.
|
||||
using GroupShape = std::array<int64_t, 2>;
|
||||
auto make_group_shape = [](torch::Tensor const& x,
|
||||
torch::Tensor const& s) -> GroupShape {
|
||||
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
||||
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
||||
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
||||
};
|
||||
|
||||
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
||||
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
|
||||
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
||||
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
|
||||
|
||||
// 1x128 per-token group scales for activations
|
||||
// 128x128 blockwise scales for weights
|
||||
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
|
||||
b_scale_group_shape == GroupShape{128, 128} &&
|
||||
a.dtype() == torch::kFloat8_e4m3fn &&
|
||||
b.dtype() == torch::kFloat8_e4m3fn),
|
||||
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
|
||||
"a_scale_group_shape must be [1, 128]. Got: [",
|
||||
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||
"]\n"
|
||||
"b_scale_group_shape must be [128, 128]. Got: [",
|
||||
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||
}
|
||||
|
||||
// 1x128 per-token group scales for activations
|
||||
// 128x128 blockwise scales for weights
|
||||
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
|
||||
b_scale_group_shape == GroupShape{128, 128} &&
|
||||
a.dtype() == torch::kFloat8_e4m3fn &&
|
||||
b.dtype() == torch::kFloat8_e4m3fn),
|
||||
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
|
||||
"a_scale_group_shape must be [1, 128]. Got: [",
|
||||
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||
"]\n"
|
||||
"b_scale_group_shape must be [128, 128]. Got: [",
|
||||
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
blockwise_func(c, a, b, a_scales, b_scales);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user