Multiple updates and refactorings (#280)
This commit is contained in:
@@ -53,11 +53,10 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
|
||||
return;
|
||||
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
|
||||
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
|
||||
const auto& [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
|
||||
a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, disable_ue8m0_cast);
|
||||
DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128);
|
||||
|
||||
// Dispatch into different implements
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
@@ -66,7 +65,9 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
|
||||
const auto& major_sfb = get_major_type_ab(sfb);
|
||||
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
|
||||
// NOTES: Only granularity 128 and FP8 are exposed in the API
|
||||
sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k,
|
||||
128, 128, major_a, major_b, compiled_dims, epilogue_type);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
@@ -229,8 +230,8 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
);
|
||||
|
||||
// Allocate output
|
||||
constexpr int num_math_warp_groups = 4;
|
||||
const auto& aligned_max_context_len = align(max_context_len, num_math_warp_groups * block_kv);
|
||||
constexpr int split_kv = 256;
|
||||
const auto& aligned_max_context_len = align(max_context_len, split_kv);
|
||||
auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q.options().dtype(torch::kFloat));
|
||||
logits = logits.slice(-1, 0, max_context_len);
|
||||
|
||||
@@ -239,7 +240,7 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
if (arch_major == 9 or arch_major == 10) {
|
||||
smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta,
|
||||
batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
|
||||
kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, num_math_warp_groups);
|
||||
kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, split_kv);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
@@ -251,7 +252,8 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
}
|
||||
return logits;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
static void register_apis(pybind11::module_& m) {
|
||||
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
||||
|
||||
@@ -139,7 +139,7 @@ static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::tuple<int, int, int>& recipe,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
const std::string& compiled_dims) {
|
||||
// Shape must be `[B, M, K] @ [B, N, K].T`
|
||||
const auto& major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
|
||||
@@ -163,15 +163,16 @@ static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
return;
|
||||
|
||||
// Transform scaling factors
|
||||
const auto& transformed_sfa = layout::transform_sf_into_required_layout(sfa, m, k, recipe, batch_size, true, false);
|
||||
const auto& transformed_sfb = layout::transform_sf_into_required_layout(sfb, n, k, recipe, batch_size, false, false);
|
||||
const auto& [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
|
||||
sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
const auto arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 10) {
|
||||
sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
const auto& major_sfb = get_major_type_ab(sfb);
|
||||
sm90_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, major_sfb, compiled_dims);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,6 +183,7 @@ static void fp8_einsum(const std::string& expr,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::tuple<int, int, int>& recipe) {
|
||||
// Some hardcoded Einstein sum kernels
|
||||
const auto arch_major = device_runtime->get_arch_major();
|
||||
if (expr == "bhr,hdr->bhd") {
|
||||
// Permute dims to satisfy the order of (batch_size, m, n, k)
|
||||
// (batch_size, m, n, k): (h, b, d, r)
|
||||
@@ -190,7 +192,7 @@ static void fp8_einsum(const std::string& expr,
|
||||
const auto& perm_d = d.permute({1, 0, 2});
|
||||
const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt;
|
||||
fp8_bmm(perm_a, perm_sfa, b.first, b.second, perm_d, perm_c, recipe, "nk");
|
||||
} else if (expr == "bhd,hdr->bhr") {
|
||||
} else if (expr == "bhd,hdr->bhr" and arch_major == 10) {
|
||||
// (batch_size, m, n, k): (h, b, r, d)
|
||||
const auto& perm_a = a.first.permute({1, 0, 2});
|
||||
const auto& perm_sfa = a.second.permute({1, 0, 2});
|
||||
@@ -199,7 +201,7 @@ static void fp8_einsum(const std::string& expr,
|
||||
const auto& perm_d = d.permute({1, 0, 2});
|
||||
const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt;
|
||||
fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, perm_d, perm_c, recipe, "nk");
|
||||
} else if (expr == "bhd,bhr->hdr") {
|
||||
} else if (expr == "bhd,bhr->hdr" and arch_major == 10) {
|
||||
// (batch_size, m, n, k): (h, d, r, b)
|
||||
const auto& perm_a = a.first.permute({1, 2, 0});
|
||||
const auto& perm_sfa = a.second.permute({1, 2, 0});
|
||||
|
||||
@@ -46,13 +46,16 @@ static bool early_return(const int& m, const int &n, const int& k,
|
||||
}
|
||||
|
||||
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
||||
static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
|
||||
static void fp8_fp4_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
std::optional<std::tuple<int, int>> recipe_a,
|
||||
std::optional<std::tuple<int, int>> recipe_b,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
// Shape must be `[M, K] @ [N, K].T`
|
||||
const auto& major_a = get_major_type_ab(a.first);
|
||||
const auto& major_b = get_major_type_ab(b.first);
|
||||
@@ -65,12 +68,11 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [m , k ] = get_shape<2>(a.first);
|
||||
const auto& [n , k_] = get_shape<2>(b.first);
|
||||
const auto& [m_, n_] = get_shape<2>(d);
|
||||
const auto arch_major = device_runtime->get_arch_major();
|
||||
const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
|
||||
const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major);
|
||||
const auto [m_, n_] = get_shape<2>(d);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
|
||||
|
||||
// Early return for trivial cases
|
||||
@@ -78,88 +80,104 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
return;
|
||||
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
|
||||
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
|
||||
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
|
||||
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, std::nullopt, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch into different implements
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
if (std::get<1>(recipe.value()) == 1) {
|
||||
const int gran_n = recipe.has_value() ? std::get<1>(recipe.value()) : std::get<0>(recipe_b.value());
|
||||
if (gran_n == 1) {
|
||||
sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
const auto& major_sfb = get_major_type_ab(sfb);
|
||||
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims);
|
||||
}
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b,
|
||||
major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
static void fp8_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)},
|
||||
d, c, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
static void fp8_fp4_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::optional<std::tuple<int, int>>& recipe_a,
|
||||
const std::optional<std::tuple<int, int>>& recipe_b,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_fp4_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)},
|
||||
d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
static void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)},
|
||||
{b.first.transpose(0, 1), b.second.transpose(0, 1)},
|
||||
d, c, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
static void fp8_fp4_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::optional<std::tuple<int, int>>& recipe_a,
|
||||
const std::optional<std::tuple<int, int>>& recipe_b,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)},
|
||||
{b.first.transpose(0, 1), b.second.transpose(0, 1)},
|
||||
d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
static void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b,
|
||||
d, c, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
static void fp8_fp4_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::optional<std::tuple<int, int>>& recipe_a,
|
||||
const std::optional<std::tuple<int, int>>& recipe_b,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b,
|
||||
d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
static void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& grouped_layout,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
std::optional<std::tuple<int, int>> recipe_a,
|
||||
std::optional<std::tuple<int, int>> recipe_b,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast,
|
||||
const bool& use_psum_layout,
|
||||
const std::optional<int>& expected_m_for_psum_layout) {
|
||||
// Shape must be `[M, K] @ [G, N, K].mT`
|
||||
const auto& major_a = get_major_type_ab(a.first);
|
||||
const auto& major_b = get_major_type_ab(b.first);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
|
||||
if (fp8_requires_k_major())
|
||||
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(m_indices.is_contiguous());
|
||||
DG_HOST_ASSERT(grouped_layout.is_contiguous());
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [m, k] = get_shape<2>(a.first);
|
||||
const auto& [num_groups, n, k_] = get_shape<3>(b.first);
|
||||
const auto& [m_, n_] = get_shape<2>(d);
|
||||
const auto& m__ = static_cast<int>(m_indices.numel());
|
||||
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
|
||||
const auto arch_major = device_runtime->get_arch_major();
|
||||
const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
|
||||
const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
|
||||
const auto [m_, n_] = get_shape<2>(d);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
|
||||
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);
|
||||
|
||||
// Layout checks
|
||||
if (use_psum_layout) {
|
||||
const auto& [num_groups_] = get_shape<1>(grouped_layout);
|
||||
DG_HOST_ASSERT(num_groups == num_groups_);
|
||||
} else {
|
||||
const auto& [m__] = get_shape<1>(grouped_layout);
|
||||
DG_HOST_ASSERT(m == m__);
|
||||
DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
|
||||
}
|
||||
|
||||
// D must be N-major
|
||||
check_major_type_cd(d);
|
||||
@@ -169,44 +187,48 @@ static void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
|
||||
return;
|
||||
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
|
||||
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
|
||||
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, num_groups, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
const auto& major_sfb = get_major_type_ab(sfb);
|
||||
sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
|
||||
DG_HOST_ASSERT(not use_psum_layout);
|
||||
sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, grouped_layout,
|
||||
num_groups, m, n, k, major_a, major_b, major_sfb, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout,
|
||||
num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b,
|
||||
compiled_dims, use_psum_layout, expected_m_for_psum_layout);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
static void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)},
|
||||
d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& grouped_layout,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::optional<std::tuple<int, int>>& recipe_a,
|
||||
const std::optional<std::tuple<int, int>>& recipe_b,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast,
|
||||
const bool& use_psum_layout) {
|
||||
m_grouped_fp8_fp4_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)},
|
||||
d, grouped_layout, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast, use_psum_layout, std::nullopt);
|
||||
}
|
||||
|
||||
static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& expected_m,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& expected_m,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
std::optional<std::tuple<int, int>> recipe_a,
|
||||
std::optional<std::tuple<int, int>> recipe_b,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
// Shape must be `[G, M, K] @ [G, N, K].mT`
|
||||
const auto& major_a = get_major_type_ab(a.first);
|
||||
const auto& major_b = get_major_type_ab(b.first);
|
||||
@@ -214,15 +236,14 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
|
||||
DG_HOST_ASSERT(masked_m.is_contiguous());
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [num_groups, m, k] = get_shape<3>(a.first);
|
||||
const auto& [num_groups_, n, k_] = get_shape<3>(b.first);
|
||||
const auto& [num_groups__, m_, n_] = get_shape<3>(d);
|
||||
const auto& num_groups___ = static_cast<int>(masked_m.numel());
|
||||
const auto arch_major = device_runtime->get_arch_major();
|
||||
const auto [num_groups , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major);
|
||||
const auto [num_groups_ , n , k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
|
||||
const auto [num_groups__, m_, n_] = get_shape<3>(d);
|
||||
const auto num_groups___ = static_cast<int>(masked_m.numel());
|
||||
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
|
||||
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);
|
||||
|
||||
@@ -230,20 +251,18 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Transform scaling factors
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast);
|
||||
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
|
||||
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
|
||||
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, num_groups, num_groups, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
const auto& major_sfb = get_major_type_ab(sfb);
|
||||
sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
||||
sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, gran_k_a, gran_k_b,
|
||||
major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
@@ -262,9 +281,10 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
|
||||
|
||||
// Shape checks
|
||||
const auto& [num_groups, m, n] = get_shape<3>(d);
|
||||
const auto& [_, m_] = get_shape<2>(a.first);
|
||||
const auto& [__, n_] = get_shape<2>(b.first);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_);
|
||||
const auto& [sum_k_ , m_] = get_shape<2>(a.first);
|
||||
const auto& [sum_k__, n_] = get_shape<2>(b.first);
|
||||
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);
|
||||
|
||||
// Contiguity checks
|
||||
DG_HOST_ASSERT(a.first.is_contiguous());
|
||||
@@ -283,8 +303,8 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 10) {
|
||||
fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
|
||||
sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
@@ -305,9 +325,7 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
|
||||
const auto& [num_groups, m, n] = get_shape<3>(d);
|
||||
const auto& sum_mk = a.first.numel();
|
||||
const auto& sum_nk = b.first.numel();
|
||||
int sum_k = 0;
|
||||
for (const auto& k: ks)
|
||||
sum_k += k;
|
||||
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
|
||||
DG_HOST_ASSERT(sum_mk == static_cast<int64_t>(sum_k) * m);
|
||||
DG_HOST_ASSERT(sum_nk == static_cast<int64_t>(sum_k) * n);
|
||||
|
||||
@@ -334,7 +352,7 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
sm90_fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
|
||||
sm90_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
|
||||
cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
@@ -404,25 +422,36 @@ static void bf16_gemm_tt(const torch::Tensor& a,
|
||||
}
|
||||
|
||||
static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& d, const torch::Tensor& m_indices,
|
||||
const std::string& compiled_dims) {
|
||||
const torch::Tensor& d, const torch::Tensor& grouped_layout,
|
||||
const std::string& compiled_dims,
|
||||
const bool& use_psum_layout,
|
||||
const std::optional<int>& expected_m_for_psum_layout) {
|
||||
// Shape must be `[M, K] @ [G, N, K].mT`
|
||||
const auto& major_a = get_major_type_ab(a);
|
||||
const auto& major_b = get_major_type_ab(b);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(m_indices.is_contiguous());
|
||||
DG_HOST_ASSERT(grouped_layout.is_contiguous());
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [m, k] = get_shape<2>(a);
|
||||
const auto& [num_groups, n, k_] = get_shape<3>(b);
|
||||
const auto& [m_, n_] = get_shape<2>(d);
|
||||
const auto& m__ = static_cast<int>(m_indices.numel());
|
||||
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
|
||||
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);
|
||||
|
||||
// Layout checks
|
||||
if (use_psum_layout) {
|
||||
const auto& [num_groups_] = get_shape<1>(grouped_layout);
|
||||
DG_HOST_ASSERT(num_groups == num_groups_);
|
||||
} else {
|
||||
const auto& [m__] = get_shape<1>(grouped_layout);
|
||||
DG_HOST_ASSERT(m == m__);
|
||||
DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
|
||||
}
|
||||
|
||||
// D must be N-major
|
||||
check_major_type_cd(d);
|
||||
@@ -434,21 +463,24 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
sm90_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices,
|
||||
DG_HOST_ASSERT(not use_psum_layout);
|
||||
sm90_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10) {
|
||||
sm100_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
sm100_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims,
|
||||
use_psum_layout, expected_m_for_psum_layout);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
}
|
||||
|
||||
static void m_grouped_bf16_gemm_nn_contiguous(const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& d, const torch::Tensor& m_indices,
|
||||
const std::string& compiled_dims) {
|
||||
const torch::Tensor& d, const torch::Tensor& grouped_layout,
|
||||
const std::string& compiled_dims,
|
||||
const bool& use_psum_layout) {
|
||||
m_grouped_bf16_gemm_nt_contiguous(a, b.transpose(1, 2),
|
||||
d, m_indices, compiled_dims);
|
||||
d, grouped_layout, compiled_dims, use_psum_layout, std::nullopt);
|
||||
}
|
||||
|
||||
static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b,
|
||||
@@ -498,9 +530,10 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a,
|
||||
const std::string& compiled_dims) {
|
||||
// Shape checks
|
||||
const auto& [num_groups, m, n] = get_shape<3>(d);
|
||||
const auto& [_, m_] = get_shape<2>(a);
|
||||
const auto& [__, n_] = get_shape<2>(b);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_);
|
||||
const auto& [sum_k_ , m_] = get_shape<2>(a);
|
||||
const auto& [sum_k__, n_] = get_shape<2>(b);
|
||||
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);
|
||||
|
||||
// Contiguity checks
|
||||
DG_HOST_ASSERT(a.is_contiguous());
|
||||
@@ -563,38 +596,50 @@ static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
|
||||
static void register_apis(pybind11::module_& m) {
|
||||
|
||||
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
||||
// FP8 GEMMs
|
||||
m.def("fp8_gemm_nt", &fp8_gemm_nt,
|
||||
// FP8 FP4 GEMMs
|
||||
m.def("fp8_fp4_gemm_nt", &fp8_fp4_gemm_nt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_gemm_nn", &fp8_gemm_nn,
|
||||
m.def("fp8_fp4_gemm_nn", &fp8_fp4_gemm_nn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_gemm_tn", &fp8_gemm_tn,
|
||||
m.def("fp8_fp4_gemm_tn", &fp8_fp4_gemm_tn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
||||
py::arg("compiled_dims") = "mn",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_gemm_tt", &fp8_gemm_tt,
|
||||
m.def("fp8_fp4_gemm_tt", &fp8_fp4_gemm_tt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
||||
py::arg("compiled_dims") = "mn",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
|
||||
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
|
||||
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked,
|
||||
m.def("m_grouped_fp8_fp4_gemm_nt_contiguous", &m_grouped_fp8_fp4_gemm_nt_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
|
||||
py::arg("recipe") = std::nullopt,
|
||||
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false,
|
||||
py::arg("use_psum_layout") = false,
|
||||
py::arg("expected_m_for_psum_layout") = std::nullopt);
|
||||
m.def("m_grouped_fp8_fp4_gemm_nn_contiguous", &m_grouped_fp8_fp4_gemm_nn_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
|
||||
py::arg("recipe") = std::nullopt,
|
||||
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false,
|
||||
py::arg("use_psum_layout") = false);
|
||||
m.def("m_grouped_fp8_fp4_gemm_nt_masked", &m_grouped_fp8_fp4_gemm_nt_masked,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
|
||||
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
|
||||
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
|
||||
@@ -606,6 +651,15 @@ static void register_apis(pybind11::module_& m) {
|
||||
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
||||
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
||||
py::arg("compiled_dims") = "mn");
|
||||
|
||||
// FP8 GEMM alias names
|
||||
m.attr("fp8_gemm_nt") = m.attr("fp8_fp4_gemm_nt");
|
||||
m.attr("fp8_gemm_nn") = m.attr("fp8_fp4_gemm_nn");
|
||||
m.attr("fp8_gemm_tn") = m.attr("fp8_fp4_gemm_tn");
|
||||
m.attr("fp8_gemm_tt") = m.attr("fp8_fp4_gemm_tt");
|
||||
m.attr("m_grouped_fp8_gemm_nt_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nt_contiguous");
|
||||
m.attr("m_grouped_fp8_gemm_nn_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nn_contiguous");
|
||||
m.attr("m_grouped_fp8_gemm_nt_masked") = m.attr("m_grouped_fp8_fp4_gemm_nt_masked");
|
||||
#endif
|
||||
|
||||
#if DG_TENSORMAP_COMPATIBLE
|
||||
@@ -627,11 +681,14 @@ static void register_apis(pybind11::module_& m) {
|
||||
py::arg("c") = std::nullopt,
|
||||
py::arg("compiled_dims") = "mn");
|
||||
m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
|
||||
py::arg("compiled_dims") = "nk");
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("use_psum_layout") = false,
|
||||
py::arg("expected_m_for_psum_layout") = std::nullopt);
|
||||
m.def("m_grouped_bf16_gemm_nn_contiguous", &m_grouped_bf16_gemm_nn_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
|
||||
py::arg("compiled_dims") = "nk");
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("use_psum_layout") = false);
|
||||
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
|
||||
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
|
||||
|
||||
70
csrc/apis/hyperconnection.hpp
Normal file
70
csrc/apis/hyperconnection.hpp
Normal file
@@ -0,0 +1,70 @@
|
||||
#pragma once
|
||||
|
||||
#include "../utils/compatibility.hpp"
|
||||
|
||||
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
||||
#include "../jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp"
|
||||
#include "../jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp"
|
||||
#endif
|
||||
|
||||
namespace deep_gemm::hyperconnection {
|
||||
|
||||
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
||||
// Host-side entry point for the TF32 hyperconnection pre-norm GEMM.
// Validates layouts, shapes and dtypes of all tensors, returns early when `m == 0`,
// and then dispatches to the SM90 or SM100 implementation by device architecture.
//
// Arguments (as enforced by the assertions below):
//   a          - BF16, 2D with shape (m, k), K-major
//   b          - FP32, 2D with shape (n, k), K-major
//   d          - FP32, shape (m, n), or (num_splits, m, n) when `num_splits` is given
//   sqr_sum    - FP32, contiguous, shape (m), or (num_splits, m) when `num_splits` is given
//   num_splits - optional split count (must be >= 1); the kernels receive 1 when it is absent
static void tf32_hc_prenorm_gemm(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& sqr_sum,
|
||||
const std::optional<int>& num_splits) {
|
||||
// A and B must be K-major, D must be N-major
|
||||
DG_HOST_ASSERT(get_major_type_ab(a) == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(get_major_type_ab(b) == cute::UMMA::Major::K);
|
||||
check_major_type_cd(d);
|
||||
|
||||
// S (the `sqr_sum` tensor) must be contiguous
|
||||
DG_HOST_ASSERT(sqr_sum.is_contiguous());
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [m, k ] = get_shape<2>(a);
|
||||
const auto& [n, k_] = get_shape<2>(b);
|
||||
if (num_splits.has_value()) {
|
||||
// With splits: D and S carry a leading dimension that must equal `num_splits`
|
||||
const auto& [num_splits_, m_, n_] = get_shape<3>(d);
|
||||
const auto& [num_splits__, m__] = get_shape<2>(sqr_sum);
|
||||
DG_HOST_ASSERT(num_splits.value() == num_splits_ and num_splits.value() == num_splits__ and num_splits.value() >= 1);
|
||||
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
|
||||
} else {
|
||||
// Without splits: D is 2D and S is 1D
|
||||
const auto& [m_, n_] = get_shape<2>(d);
|
||||
const auto& [m__] = get_shape<1>(sqr_sum);
|
||||
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
|
||||
}
|
||||
// Only `m` is allowed to be zero (handled by the early return below)
|
||||
DG_HOST_ASSERT(n > 0 and k > 0);
|
||||
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(b.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(sqr_sum.scalar_type() == torch::kFloat);
|
||||
|
||||
// Do nothing if the problem is empty
|
||||
if (m == 0)
|
||||
return;
|
||||
|
||||
// Dispatch to the architecture-specific implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
sm90_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1);
|
||||
} else if (arch_major == 10) {
|
||||
sm100_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Registers the hyperconnection API on the given pybind11 module.
// `tf32_hc_prenorm_gemm` is exposed with keyword arguments `a`, `b`, `d`, `sqr_sum`
// and an optional `num_splits` (defaults to `std::nullopt`).
// When the compatibility macros below evaluate to false, nothing is registered.
static void register_apis(pybind11::module_& m) {
|
||||
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
||||
m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("sqr_sum"),
|
||||
py::arg("num_splits") = std::nullopt);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::hyperconnection
|
||||
@@ -1,20 +1,34 @@
|
||||
#pragma once
|
||||
|
||||
#include "../utils/layout.hpp"
|
||||
#include "../utils/compatibility.hpp"
|
||||
|
||||
#if DG_TENSORMAP_COMPATIBLE
|
||||
#include "../jit_kernels/impls/smxx_layout.hpp"
|
||||
#endif
|
||||
|
||||
namespace deep_gemm::layout {
|
||||
|
||||
#if DG_TENSORMAP_COMPATIBLE
|
||||
static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
|
||||
const int& mn, const int& k,
|
||||
const std::tuple<int, int, int>& recipe,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::optional<std::tuple<int, int>>& recipe_ab,
|
||||
const std::optional<int>& num_groups,
|
||||
const bool& is_sfa,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
|
||||
const auto& gran_k = std::get<2>(recipe);
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
|
||||
int gran_mn, gran_k;
|
||||
if (recipe.has_value()) {
|
||||
DG_HOST_ASSERT(not recipe_ab.has_value());
|
||||
gran_mn = is_sfa ? std::get<0>(recipe.value()) : std::get<1>(recipe.value());
|
||||
gran_k = std::get<2>(recipe.value());
|
||||
} else {
|
||||
DG_HOST_ASSERT(recipe_ab.has_value());
|
||||
std::tie(gran_mn, gran_k) = recipe_ab.value();
|
||||
}
|
||||
|
||||
// Pre-transform checks
|
||||
check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups);
|
||||
|
||||
@@ -22,30 +36,44 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast))
|
||||
return get_mn_major_tma_aligned_tensor(sf);
|
||||
|
||||
// (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) {
|
||||
DG_HOST_ASSERT(not disable_ue8m0_cast);
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
|
||||
}
|
||||
|
||||
// (FP32, 128, 128) on SM90: no need to transform, check SFB requirements
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast))
|
||||
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);
|
||||
|
||||
// (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) {
|
||||
// (FP32, x, gran_k) on SM100: transform to (INT, 1, gran_k), TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kFloat and (gran_k == 32 or gran_k == 128) and arch_major == 10) {
|
||||
DG_HOST_ASSERT(not disable_ue8m0_cast);
|
||||
const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128));
|
||||
const auto& broadcasted = gran_mn == 1 ? sf :
|
||||
sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn));
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted);
|
||||
}
|
||||
|
||||
// (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10)
|
||||
// (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10)
|
||||
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);
|
||||
|
||||
DG_HOST_UNREACHABLE("Unknown SF transformation");
|
||||
}
|
||||
|
||||
// Transforms the scaling factors of both operands (SFA and SFB) in one call and
// reports the K granularity used for each.
//
// Recipe selection:
//   - `recipe_a` and `recipe_b` must be both set or both unset (asserted).
//   - If neither the per-operand recipes nor `recipe` is given, `recipe` is filled with
//     `get_default_recipe(...)` of the two SF scalar types. `recipe` is a non-const
//     reference, so this assignment is visible to the caller.
//
// SFA is transformed with extents (m, k) and `is_sfa = true`; SFB with (n, k) and
// `is_sfa = false`. `num_groups_a` / `num_groups_b` and `disable_ue8m0_cast` are
// forwarded unchanged to `transform_sf_into_required_layout`.
//
// Returns (transformed SFA, transformed SFB, gran_k_a, gran_k_b), where each `gran_k`
// is element 1 of the per-operand recipe when present, otherwise element 2 of `recipe`.
static std::tuple<torch::Tensor, torch::Tensor, int, int> transform_sf_pair_into_required_layout(
|
||||
const torch::Tensor& sfa, const torch::Tensor& sfb,
|
||||
const int& m, const int& n, const int& k,
|
||||
std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::optional<std::tuple<int, int>>& recipe_a,
|
||||
const std::optional<std::tuple<int, int>>& recipe_b,
|
||||
const std::optional<int>& num_groups_a,
|
||||
const std::optional<int>& num_groups_b,
|
||||
const bool& disable_ue8m0_cast = false) {
|
||||
// Per-operand recipes come as a pair: either both or none
|
||||
DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value());
|
||||
// No recipe at all: fall back to the default for these SF dtypes (writes through `recipe`)
|
||||
if (not recipe_a.has_value() and not recipe.has_value())
|
||||
recipe = get_default_recipe(sfa.scalar_type(), sfb.scalar_type());
|
||||
const auto transformed_sfa = transform_sf_into_required_layout(sfa, m, k, recipe, recipe_a, num_groups_a, true, disable_ue8m0_cast);
|
||||
const auto transformed_sfb = transform_sf_into_required_layout(sfb, n, k, recipe, recipe_b, num_groups_b, false, disable_ue8m0_cast);
|
||||
// K granularity: element 1 of the per-operand recipe, or element 2 of the shared recipe
|
||||
const int gran_k_a = recipe_a.has_value() ? std::get<1>(recipe_a.value()) : std::get<2>(recipe.value());
|
||||
const int gran_k_b = recipe_b.has_value() ? std::get<1>(recipe_b.value()) : std::get<2>(recipe.value());
|
||||
return std::make_tuple(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b);
|
||||
}
|
||||
|
||||
static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
|
||||
const std::vector<int>& ks,
|
||||
const torch::Tensor& ks_tensor,
|
||||
@@ -69,17 +97,24 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
|
||||
DG_HOST_UNREACHABLE("Unknown cases");
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Registers the layout helper functions on the given pybind11 module.
// NOTE(review): this span is taken from a rendered diff, so some rows appear in two
// forms: the `py::arg("recipe")` row (required vs. defaulted to `std::nullopt`) and the
// `get_mk_alignment_for_contiguous_layout` registration (inside the guard vs. after `#endif`).
static void register_apis(pybind11::module_& m) {
|
||||
|
||||
// Everything up to the matching `#endif` is registered only when DG_TENSORMAP_COMPATIBLE holds
#if DG_TENSORMAP_COMPATIBLE
|
||||
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
|
||||
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
|
||||
py::arg("sf"), py::arg("mn"), py::arg("k"),
|
||||
py::arg("recipe") = std::nullopt, py::arg("recipe_ab") = std::nullopt,
|
||||
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
|
||||
m.def("get_tma_aligned_size", &get_tma_aligned_size);
|
||||
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
|
||||
m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
|
||||
m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
|
||||
m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
|
||||
#endif
|
||||
|
||||
// Registered outside the guard above
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::layout
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#if DG_TENSORMAP_COMPATIBLE
|
||||
#include "../jit/compiler.hpp"
|
||||
#endif
|
||||
#include "../jit/device_runtime.hpp"
|
||||
|
||||
namespace deep_gemm::runtime {
|
||||
@@ -18,10 +20,11 @@ static void register_apis(pybind11::module_& m) {
|
||||
m.def("get_tc_util", [&]() {
|
||||
return device_runtime->get_tc_util();
|
||||
});
|
||||
|
||||
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
|
||||
#if DG_TENSORMAP_COMPATIBLE
|
||||
Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
|
||||
KernelRuntime::prepare_init(cuda_home_path_by_python);
|
||||
#endif
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,10 @@
|
||||
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
|
||||
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
|
||||
|
||||
// Hyperconnection kernels
|
||||
#include <deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh>
|
||||
#include <deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh>
|
||||
|
||||
// Layout kernels
|
||||
#include <deep_gemm/impls/smxx_layout.cuh>
|
||||
#include <deep_gemm/impls/smxx_clean_logits.cuh>
|
||||
|
||||
@@ -24,6 +24,7 @@ public:
|
||||
static std::filesystem::path library_include_path;
|
||||
static std::filesystem::path cuda_home;
|
||||
static std::string library_version;
|
||||
static std::filesystem::path cuobjdump_path;
|
||||
|
||||
static std::string get_library_version() {
|
||||
std::vector<char> buffer;
|
||||
@@ -45,6 +46,7 @@ public:
|
||||
Compiler::library_include_path = Compiler::library_root_path / "include";
|
||||
Compiler::cuda_home = cuda_home_path_by_python;
|
||||
Compiler::library_version = get_library_version();
|
||||
Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump";
|
||||
}
|
||||
|
||||
std::string signature, flags;
|
||||
@@ -56,6 +58,7 @@ public:
|
||||
DG_HOST_ASSERT(not library_include_path.empty());
|
||||
DG_HOST_ASSERT(not cuda_home.empty());
|
||||
DG_HOST_ASSERT(not library_version.empty());
|
||||
DG_HOST_ASSERT(not cuobjdump_path.empty());
|
||||
|
||||
// Cache settings
|
||||
cache_dir_path = std::filesystem::path(get_env<std::string>("HOME")) / ".deep_gemm";
|
||||
@@ -108,25 +111,57 @@ public:
|
||||
|
||||
// Compile into a temporary CUBIN
|
||||
const auto tmp_cubin_path = get_tmp_file_path();
|
||||
compile(code, dir_path, tmp_cubin_path);
|
||||
if (get_env<int>("DG_JIT_DUMP_ASM") or get_env<int>("DG_JIT_DUMP_PTX")) {
|
||||
// Dump PTX if needed
|
||||
const auto tmp_ptx_path = get_tmp_file_path();
|
||||
compile(code, dir_path, tmp_cubin_path, tmp_ptx_path);
|
||||
|
||||
// Replace into the cache directory
|
||||
std::filesystem::rename(tmp_ptx_path, dir_path / "kernel.ptx");
|
||||
} else {
|
||||
compile(code, dir_path, tmp_cubin_path);
|
||||
}
|
||||
|
||||
// Replace into the cache directory
|
||||
make_dirs(dir_path);
|
||||
std::filesystem::rename(tmp_cubin_path, dir_path / "kernel.cubin");
|
||||
const auto cubin_path = dir_path / "kernel.cubin";
|
||||
std::filesystem::rename(tmp_cubin_path, cubin_path);
|
||||
|
||||
// Disassemble if needed
|
||||
if (get_env<int>("DG_JIT_DUMP_ASM") or get_env<int>("DG_JIT_DUMP_SASS")) {
|
||||
// Dump into a temporary SASS
|
||||
const auto tmp_sass_path = get_tmp_file_path();
|
||||
disassemble(cubin_path, tmp_sass_path);
|
||||
|
||||
// Replace into the cache directory
|
||||
std::filesystem::rename(tmp_sass_path, dir_path / "kernel.sass");
|
||||
}
|
||||
|
||||
// Put into the runtime cache
|
||||
const auto& runtime = kernel_runtime_cache->get(dir_path);
|
||||
const auto runtime = kernel_runtime_cache->get(dir_path);
|
||||
DG_HOST_ASSERT(runtime != nullptr);
|
||||
return runtime;
|
||||
}
|
||||
|
||||
virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0;
|
||||
// Dumps the SASS of a compiled CUBIN into a file by running
// `<cuobjdump_path> --dump-sass <cubin_path> > <sass_path>` through `call_external_command`.
// The command is printed first when `DG_JIT_DEBUG` or `DG_JIT_PRINT_COMPILER_COMMAND`
// evaluates to true. On a non-zero return code the captured output is printed and a
// host assertion fails.
// NOTE(review): the three paths are interpolated into the command without quoting, so
// paths containing spaces or shell metacharacters would not be passed through intact.
static void disassemble(const std::filesystem::path &cubin_path, const std::filesystem::path &sass_path) {
|
||||
// Disassemble the CUBIN file to SASS
|
||||
const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.c_str(), cubin_path.c_str(), sass_path.c_str());
|
||||
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
|
||||
printf("Running cuobjdump command: %s\n", command.c_str());
|
||||
const auto [return_code, output] = call_external_command(command);
|
||||
// Any non-zero exit status is treated as a hard failure
|
||||
if (return_code != 0) {
|
||||
printf("cuobjdump failed: %s\n", output.c_str());
|
||||
DG_HOST_ASSERT(false and "cuobjdump failed");
|
||||
}
|
||||
}
|
||||
|
||||
virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path, const std::optional<std::filesystem::path> &ptx_path = std::nullopt) const = 0;
|
||||
};
|
||||
|
||||
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path);
|
||||
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path);
|
||||
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home);
|
||||
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_version);
|
||||
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path);
|
||||
|
||||
class NVCCCompiler final: public Compiler {
|
||||
std::filesystem::path nvcc_path;
|
||||
@@ -164,17 +199,19 @@ public:
|
||||
const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9);
|
||||
flags = fmt::format("{} -I{} --gpu-architecture=sm_{} "
|
||||
"--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
|
||||
"-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
|
||||
"-O3 --expt-relaxed-constexpr --expt-extended-lambda",
|
||||
flags, library_include_path.c_str(), arch);
|
||||
}
|
||||
|
||||
void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
|
||||
void compile(const std::string &code, const std::filesystem::path& dir_path,
|
||||
const std::filesystem::path &cubin_path,
|
||||
const std::optional<std::filesystem::path> &ptx_path) const override {
|
||||
// Write the code into the cache directory
|
||||
const auto& code_path = dir_path / "kernel.cu";
|
||||
put(code_path, code);
|
||||
|
||||
// Compile
|
||||
const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
|
||||
const auto& command = fmt::format("{} {} -cubin -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
|
||||
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
|
||||
printf("Running NVCC command: %s\n", command.c_str());
|
||||
const auto& [return_code, output] = call_external_command(command);
|
||||
@@ -183,6 +220,18 @@ public:
|
||||
DG_HOST_ASSERT(false and "NVCC compilation failed");
|
||||
}
|
||||
|
||||
// Compile to PTX if needed
|
||||
if (ptx_path.has_value()) {
|
||||
const auto ptx_command = fmt::format("{} {} -ptx -o {} {}", nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags);
|
||||
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
|
||||
printf("Running NVCC PTX command: %s\n", ptx_command.c_str());
|
||||
const auto [ptx_return_code, ptx_output] = call_external_command(ptx_command);
|
||||
if (ptx_return_code != 0) {
|
||||
printf("NVCC PTX compilation failed: %s\n", ptx_output.c_str());
|
||||
DG_HOST_ASSERT(false and "NVCC PTX compilation failed");
|
||||
}
|
||||
}
|
||||
|
||||
// Check local memory usage
|
||||
if (get_env("DG_JIT_PTXAS_CHECK", 0))
|
||||
DG_HOST_ASSERT(not std::regex_search(output, std::regex(R"(Local memory used)")));
|
||||
@@ -219,11 +268,13 @@ public:
|
||||
// Override the compiler flags
|
||||
// Only NVRTC >= 12.9 supports arch-specific family suffix
|
||||
const auto& arch = device_runtime->get_arch(false, major > 12 or minor >= 9);
|
||||
flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}",
|
||||
flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {} --device-int128",
|
||||
flags, include_dirs, arch, pch_flags);
|
||||
}
|
||||
|
||||
void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
|
||||
void compile(const std::string &code, const std::filesystem::path& dir_path,
|
||||
const std::filesystem::path &cubin_path,
|
||||
const std::optional<std::filesystem::path> &ptx_path) const override {
|
||||
// Write the code into the cache directory
|
||||
const auto& code_path = dir_path / "kernel.cu";
|
||||
put(code_path, code);
|
||||
@@ -266,6 +317,17 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
if (ptx_path.has_value()) {
|
||||
// Get PTX size and data if needed
|
||||
size_t ptx_size;
|
||||
DG_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
|
||||
std::string ptx_data(ptx_size, '\0');
|
||||
DG_NVRTC_CHECK(nvrtcGetPTX(program, ptx_data.data()));
|
||||
|
||||
// Write into the file system
|
||||
put(ptx_path.value(), ptx_data);
|
||||
}
|
||||
|
||||
// Get CUBIN size and data
|
||||
size_t cubin_size;
|
||||
DG_NVRTC_CHECK(nvrtcGetCUBINSize(program, &cubin_size));
|
||||
|
||||
@@ -17,19 +17,7 @@ class DeviceRuntime {
|
||||
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
|
||||
|
||||
public:
|
||||
#if TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 3)
|
||||
// For PyTorch 2.3+, share the PyTorch cuBLASLt handle
|
||||
DeviceRuntime() = default;
|
||||
|
||||
static cublasLtHandle_t get_cublaslt_handle() {
|
||||
return at::cuda::getCurrentCUDABlasLtHandle();
|
||||
}
|
||||
|
||||
static torch::Tensor get_cublaslt_workspace() {
|
||||
return torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA));
|
||||
}
|
||||
#else
|
||||
// Otherwise, create the cuBLASLt handle ourselves
|
||||
// Create the cuBLASLt handle ourselves
|
||||
cublasLtHandle_t cublaslt_handle{};
|
||||
std::shared_ptr<torch::Tensor> cublaslt_workspace;
|
||||
|
||||
@@ -49,7 +37,6 @@ public:
|
||||
torch::Tensor get_cublaslt_workspace() const {
|
||||
return *cublaslt_workspace;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<cudaDeviceProp> get_prop() {
|
||||
if (cached_prop == nullptr) {
|
||||
|
||||
@@ -40,10 +40,7 @@ DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleLoad);
|
||||
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleUnload);
|
||||
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleGetFunction);
|
||||
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx);
|
||||
|
||||
#if DG_TENSORMAP_COMPATIBLE
|
||||
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled);
|
||||
#endif
|
||||
|
||||
#if CUDART_VERSION >= 12080 and defined(DG_JIT_USE_RUNTIME_API)
|
||||
|
||||
@@ -166,7 +163,6 @@ static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle&
|
||||
void *ptr_args[] = { &args... };
|
||||
return lazy_cuLaunchKernelEx(&config, kernel, ptr_args, nullptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -59,7 +59,8 @@ struct GemmConfig {
|
||||
// Templated configs
|
||||
GemmType gemm_type;
|
||||
KernelType kernel_type;
|
||||
at::ScalarType ab_dtype, cd_dtype;
|
||||
MmaKind mma_kind;
|
||||
at::ScalarType a_dtype, b_dtype, cd_dtype;
|
||||
cute::UMMA::Major major_a;
|
||||
cute::UMMA::Major major_b;
|
||||
bool with_accumulation;
|
||||
@@ -99,9 +100,9 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
|
||||
const int& m, const int& n, const int& k,
|
||||
const int& block_m, const int& block_n, const int& block_k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
|
||||
const int& num_stages, const MulticastConfig& multicast_config) {
|
||||
const int& ab_elem_size = static_cast<int>(c10::elementSize(ab_dtype));
|
||||
const int& ab_elem_size = static_cast<int>(get_element_size(mma_kind));
|
||||
const int& cd_elem_size = static_cast<int>(c10::elementSize(cd_dtype));
|
||||
|
||||
const int& load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m);
|
||||
@@ -119,7 +120,7 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
|
||||
|
||||
// SF shared memory
|
||||
const auto& [smem_sfa_per_stage, smem_sfb_per_stage] =
|
||||
ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, ab_dtype, cd_dtype);
|
||||
ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, mma_kind, cd_dtype);
|
||||
const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k);
|
||||
|
||||
// M-barriers and tensor memory pointers
|
||||
@@ -151,21 +152,35 @@ template <typename ArchSpec>
|
||||
static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type,
|
||||
const int& m, const int& n, const int& k, const int& num_groups,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const at::ScalarType& a_dtype, const at::ScalarType& b_dtype,
|
||||
const at::ScalarType& cd_dtype,
|
||||
const bool& with_accumulation, const int& num_sms) {
|
||||
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16);
|
||||
const auto mma_kind = (a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4);
|
||||
if (mma_kind == MmaKind::BF16) {
|
||||
DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16);
|
||||
} else {
|
||||
DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4);
|
||||
DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4);
|
||||
}
|
||||
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
|
||||
|
||||
// Select M/N block sizes
|
||||
auto block_ms = ArchSpec::get_block_m_candidates(kernel_type, major_a, m);
|
||||
if (gemm_type == GemmType::MGroupedContiguous)
|
||||
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
|
||||
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
|
||||
block_ms = std::vector{64, 128};
|
||||
const auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype);
|
||||
if (gemm_type == GemmType::MGroupedMasked or gemm_type == GemmType::MGroupedContiguousWithPsumLayout)
|
||||
block_ms = std::vector{64, 128}; // Exclude 256 for performance
|
||||
auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype);
|
||||
|
||||
// NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B
|
||||
// TODO: Optimize it
|
||||
if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN)
|
||||
block_ms = std::vector{128};
|
||||
if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN)
|
||||
block_ns = std::vector{128};
|
||||
|
||||
// K block size is selected in a fixed manner
|
||||
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
|
||||
const auto& block_k = (mma_kind == MmaKind::BF16 ? 64 : 128);
|
||||
|
||||
// Some util functions
|
||||
const auto& get_num_blocks = [=](const int& block_m, const int& block_n) {
|
||||
@@ -186,7 +201,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
for (const auto& block_n: block_ns) {
|
||||
const int& num_waves = get_num_waves(block_m, block_n);
|
||||
const auto& last_util = get_last_wave_util(block_m, block_n);
|
||||
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, m, n, k, block_m, block_n, block_k))
|
||||
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, mma_kind, cd_dtype, m, n, k, block_m, block_n, block_k))
|
||||
continue;
|
||||
|
||||
bool success = false;
|
||||
@@ -218,8 +233,16 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
|
||||
// Decide the number of TMA multicasts and whether broadcast on A
|
||||
MulticastConfig best_multicast_config = {1, false};
|
||||
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
|
||||
auto [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
|
||||
gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms);
|
||||
|
||||
// NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B
|
||||
// TODO: Optimize it
|
||||
if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN)
|
||||
is_legal_on_a = false;
|
||||
if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN)
|
||||
is_legal_on_b = false;
|
||||
|
||||
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
|
||||
bool order[2] = {false, true};
|
||||
if (best_block_m > best_block_n)
|
||||
@@ -236,14 +259,14 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
int best_num_stages = 0;
|
||||
SharedMemoryConfig best_smem_config;
|
||||
for (int num_stages = 32; num_stages > 0; -- num_stages) {
|
||||
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
|
||||
if (not ArchSpec::is_num_stages_legal(mma_kind, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
|
||||
continue;
|
||||
|
||||
best_smem_config = get_smem_config<ArchSpec>(gemm_type, kernel_type,
|
||||
m, n, k,
|
||||
best_block_m, best_block_n, block_k,
|
||||
major_a, major_b,
|
||||
ab_dtype, cd_dtype,
|
||||
mma_kind, cd_dtype,
|
||||
num_stages, best_multicast_config);
|
||||
if (best_smem_config.smem_size <= smem_capacity) {
|
||||
best_num_stages = num_stages;
|
||||
@@ -255,7 +278,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
// Recompute the minimal number of SMs required
|
||||
// NOTES: less L2 cache usage and less GPU frequency drop
|
||||
int num_min_sms = num_sms;
|
||||
if (ArchSpec::should_minimize_num_sms()) {
|
||||
if (get_env<int>("DG_JIT_MINIMIZE_NUM_SMS", 0)) {
|
||||
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves);
|
||||
num_min_sms = align(num_min_sms, best_multicast_config.num_multicast);
|
||||
DG_HOST_ASSERT(num_min_sms <= num_sms);
|
||||
@@ -264,7 +287,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
const auto& config = GemmConfig {
|
||||
.gemm_type = gemm_type,
|
||||
.kernel_type = kernel_type,
|
||||
.ab_dtype = ab_dtype,
|
||||
.mma_kind = mma_kind,
|
||||
.a_dtype = a_dtype,
|
||||
.b_dtype = b_dtype,
|
||||
.cd_dtype = cd_dtype,
|
||||
.major_a = major_a,
|
||||
.major_b = major_b,
|
||||
@@ -284,21 +309,22 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
|
||||
// Only SM100 BF16 kernels support tensor core control
|
||||
if (config.tc_util < 100)
|
||||
DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and ab_dtype == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and mma_kind == MmaKind::BF16);
|
||||
|
||||
// Print configs for the first time
|
||||
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
|
||||
auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b,
|
||||
ab_dtype, cd_dtype, with_accumulation, num_sms);
|
||||
mma_kind, a_dtype, b_dtype, cd_dtype, with_accumulation, num_sms);
|
||||
static std::set<decltype(key)> printed;
|
||||
if (printed.count(key) == 0) {
|
||||
printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, "
|
||||
"A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, "
|
||||
"A major: %d, B major: %d, MMA kind: %d, A dtype: %s, B dtype: %s, CD dtype: %s, accumulation: %d, "
|
||||
"SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, "
|
||||
"SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, "
|
||||
"swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n",
|
||||
static_cast<int>(gemm_type), static_cast<int>(kernel_type), m, n, k, num_groups,
|
||||
static_cast<int>(major_a), static_cast<int>(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype),
|
||||
static_cast<int>(major_a), static_cast<int>(major_b), static_cast<int>(mma_kind),
|
||||
c10::toString(a_dtype), c10::toString(b_dtype), c10::toString(cd_dtype),
|
||||
static_cast<int>(with_accumulation), num_sms, best_block_m, best_block_n, block_k,
|
||||
best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast,
|
||||
static_cast<int>(best_multicast_config.is_multicast_on_a),
|
||||
|
||||
@@ -53,18 +53,18 @@ struct SM100ArchSpec {
|
||||
}
|
||||
|
||||
// Returns a pair of block sizes derived from `block_m` and `block_n`:
// `{0, 0}` for the BF16 case, otherwise both values passed through
// `align(..., 128)` (presumably rounding up to a multiple of 128 — confirm
// against the `align` helper, which is not visible here).
// NOTE(review): this listing is a rendered diff. The lines mentioning `ab_dtype`
// look like the pre-change form and the lines mentioning `mma_kind` the
// post-change form of the same statements — confirm against the real file.
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
|
||||
const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) {
|
||||
const int& block_m, const int& block_n, const MmaKind& mma_kind) {
|
||||
// Alignment unit applied to both block sizes below
constexpr int num_utccp_aligned_elems = 128;
|
||||
switch (ab_dtype) {
|
||||
case torch::kBFloat16: return {0, 0};
|
||||
case torch::kFloat8_e4m3fn: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)};
|
||||
switch (mma_kind) {
|
||||
case MmaKind::BF16: return {0, 0};
|
||||
case MmaKind::MXFP8FP4: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)};
|
||||
// NOTE(review): the message still says "dtype" although the `mma_kind` switch dispatches on an MMA kind
default: DG_HOST_UNREACHABLE("Unknown dtype");
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
|
||||
const int& m, const int& n, const int& k,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// Layout A/D does not support `block_n % 16 != 0`
|
||||
@@ -82,7 +82,7 @@ struct SM100ArchSpec {
|
||||
// Check tensor memory validity
|
||||
int sf_block_m = 0, sf_block_n = 0;
|
||||
if (kernel_type == KernelType::Kernel1D1D) {
|
||||
const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype);
|
||||
const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind);
|
||||
sf_block_m = sf_block_m_, sf_block_n = sf_block_n_;
|
||||
}
|
||||
if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512)
|
||||
@@ -90,19 +90,15 @@ struct SM100ArchSpec {
|
||||
|
||||
// NOTES: when B is MN-major, we restrict `block_n` to multiples of 64,
|
||||
// since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA
|
||||
return major_b == cute::UMMA::Major::K or (block_n * c10::elementSize(ab_dtype)) % 64 == 0;
|
||||
return major_b == cute::UMMA::Major::K or (block_n * get_element_size(mma_kind)) % 64 == 0;
|
||||
}
|
||||
|
||||
// SM100 variant (per the enclosing hunk header): accepts every stage count.
// None of the parameters are read by the visible body; it always returns `true`.
// NOTE(review): the two signature lines look like the pre-change (`ab_dtype`) and
// post-change (`mma_kind`) forms of one declaration in this rendered diff — confirm.
static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
|
||||
const int& num_stages,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// SM100 variant (per the enclosing hunk header): unconditionally returns `true`.
// NOTE(review): the hunk header shrinks this region from 19 to 15 lines, and
// `get_best_config` is shown next to a `DG_JIT_MINIMIZE_NUM_SMS` environment check,
// so this helper looks like it is deleted by the change — confirm.
static bool should_minimize_num_sms() {
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
|
||||
const int& m, const int& n, const int& block_m, const int& block_n,
|
||||
const int& num_sms) {
|
||||
@@ -129,14 +125,14 @@ struct SM100ArchSpec {
|
||||
|
||||
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
|
||||
const int& block_m, const int& block_n, const int& block_k,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
|
||||
if (ab_dtype == torch::kBFloat16)
|
||||
const MmaKind& mma_kind, const at::ScalarType& cd_dtype) {
|
||||
if (mma_kind == MmaKind::BF16)
|
||||
return {0, 0};
|
||||
|
||||
int smem_sfa_per_stage = 0;
|
||||
int smem_sfb_per_stage = 0;
|
||||
if (kernel_type == KernelType::Kernel1D1D) {
|
||||
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype);
|
||||
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind);
|
||||
smem_sfa_per_stage = sf_block_m * 4;
|
||||
smem_sfb_per_stage = sf_block_n * 4;
|
||||
} else {
|
||||
|
||||
@@ -60,7 +60,7 @@ struct SM90ArchSpec {
|
||||
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
|
||||
const int& m, const int& n, const int& k,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// SM90 FP32 output does not support `block_m == 256`
|
||||
@@ -89,19 +89,15 @@ struct SM90ArchSpec {
|
||||
return block_m <= 128 or block_n <= 128;
|
||||
}
|
||||
|
||||
// SM90 variant (per the enclosing hunk header): decides whether `num_stages` is
// acceptable for the given block sizes.
// At most 4 stages are accepted when all three conditions hold:
//   - the MMA kind is MXFP8FP4 (the other condition line tests FP8 e4m3 as `ab_dtype`),
//   - `block_k` is not a multiple of `block_n`,
//   - `block_k / gcd(block_n, block_k) <= 4`.
// Every other case returns `true`. `cd_dtype` and `block_m` are not read here.
// NOTE(review): the paired `ab_dtype` / `mma_kind` lines look like the pre- and
// post-change forms of the same statements in this rendered diff — confirm.
static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
|
||||
const int& num_stages,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// Unrolling both stages and `num_former_iters` will cause large code size
|
||||
if (ab_dtype == torch::kFloat8_e4m3fn and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4)
|
||||
if (mma_kind == MmaKind::MXFP8FP4 and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4)
|
||||
return num_stages <= 4;
|
||||
return true;
|
||||
}
|
||||
|
||||
// SM90 variant (per the enclosing hunk header): unconditionally returns `true`.
// NOTE(review): the hunk header shrinks this region from 19 to 15 lines, and
// `get_best_config` is shown next to a `DG_JIT_MINIMIZE_NUM_SMS` environment check,
// so this helper looks like it is deleted by the change — confirm.
static bool should_minimize_num_sms() {
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
|
||||
const int& m, const int& n, const int& block_m, const int& block_n,
|
||||
const int& num_sms) {
|
||||
@@ -134,8 +130,8 @@ struct SM90ArchSpec {
|
||||
|
||||
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
|
||||
const int& block_m, const int& block_n, const int& block_k,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
|
||||
if (ab_dtype == torch::kBFloat16)
|
||||
const MmaKind& mma_kind, const at::ScalarType& cd_dtype) {
|
||||
if (mma_kind == MmaKind::BF16)
|
||||
return {0, 0};
|
||||
|
||||
// NOTES: 128 is for 2D TMA alignment requirement
|
||||
|
||||
@@ -37,11 +37,12 @@ static std::string to_string(const cute::UMMA::Major& major) {
|
||||
|
||||
// Maps a `GemmType` enumerator to its qualified name as a string
// (e.g. `GemmType::Normal` -> "GemmType::Normal").
// Falls through to `DG_HOST_UNREACHABLE` if no case matches.
// NOTE(review): the case labels appear twice because this is a rendered diff.
// The first five look like the pre-change lines; the following six look like the
// post-change lines, which add `MGroupedContiguousWithPsumLayout` — confirm.
static std::string to_string(const GemmType& type) {
|
||||
switch (type) {
|
||||
case GemmType::Normal: return "GemmType::Normal";
|
||||
case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous";
|
||||
case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked";
|
||||
case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous";
|
||||
case GemmType::Batched: return "GemmType::Batched";
|
||||
case GemmType::Normal: return "GemmType::Normal";
|
||||
case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous";
|
||||
case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked";
|
||||
case GemmType::MGroupedContiguousWithPsumLayout: return "GemmType::MGroupedContiguousWithPsumLayout";
|
||||
case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous";
|
||||
case GemmType::Batched: return "GemmType::Batched";
|
||||
}
|
||||
// Reached only when `type` matches none of the cases above
DG_HOST_UNREACHABLE("Unknown GEMM type");
|
||||
}
|
||||
@@ -51,6 +52,8 @@ static std::string to_string(const at::ScalarType& dtype) {
|
||||
case torch::kInt: return "int";
|
||||
case torch::kFloat: return "float";
|
||||
case torch::kBFloat16: return "cutlass::bfloat16_t";
|
||||
case torch::kFloat8_e4m3fn: return "cutlass::float_e4m3_t";
|
||||
case kPackedFP4: return "cutlass::detail::float_e2m1_unpacksmem_t";
|
||||
default: DG_HOST_UNREACHABLE("Unsupported dtype");
|
||||
}
|
||||
}
|
||||
@@ -65,6 +68,7 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
|
||||
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
case kPackedFP4: return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;
|
||||
default: DG_HOST_UNREACHABLE("Unsupported dtype");
|
||||
}
|
||||
}
|
||||
@@ -98,6 +102,10 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
if (swizzle_mode != 0)
|
||||
smem_inner_dim = swizzle_mode / elem_size;
|
||||
|
||||
// Inner dim must be a multiple of 64B for .b4x16_p64
|
||||
if (t.scalar_type() == kPackedFP4)
|
||||
DG_HOST_ASSERT(gmem_inner_dim % 128 == 0);
|
||||
|
||||
CUtensorMap tensor_map;
|
||||
const cuuint64_t gmem_dims[2] = {static_cast<cuuint64_t>(gmem_inner_dim), static_cast<cuuint64_t>(gmem_outer_dim)};
|
||||
const cuuint32_t smem_dims[2] = {static_cast<cuuint32_t>(smem_inner_dim), static_cast<cuuint32_t>(smem_outer_dim)};
|
||||
@@ -126,6 +134,10 @@ static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
|
||||
if (swizzle_mode != 0)
|
||||
smem_dim_0 = swizzle_mode / elem_size;
|
||||
|
||||
// Inner dim must be a multiple of 64B for .b4x16_p64
|
||||
if (t.scalar_type() == kPackedFP4)
|
||||
DG_HOST_ASSERT(gmem_dim_0 % 128 == 0);
|
||||
|
||||
CUtensorMap tensor_map;
|
||||
const cuuint64_t gmem_dims[3] = {static_cast<cuuint64_t>(gmem_dim_0), static_cast<cuuint64_t>(gmem_dim_1), static_cast<cuuint64_t>(gmem_dim_2),};
|
||||
const cuuint32_t smem_dims[3] = {static_cast<cuuint32_t>(smem_dim_0), static_cast<cuuint32_t>(smem_dim_1), static_cast<cuuint32_t>(smem_dim_2)};
|
||||
@@ -204,7 +216,7 @@ static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
|
||||
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
const torch::Tensor& t,
|
||||
int shape_mn, int shape_k,
|
||||
const int& block_mn, const int& block_k,
|
||||
const int& block_mn, const int& gran_k,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
@@ -215,7 +227,7 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
|
||||
shape_mn = get_tma_aligned_size(shape_mn, static_cast<int>(t.element_size()));
|
||||
return make_tma_2d_desc(t,
|
||||
shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
|
||||
shape_mn, ceil_div(shape_k, gran_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
|
||||
block_mn, 1,
|
||||
shape_mn,
|
||||
swizzle_mode, swizzle_base,
|
||||
|
||||
@@ -79,11 +79,11 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 64);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), c.has_value(),
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
@@ -104,7 +104,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
|
||||
// Launch
|
||||
const SM100BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
@@ -124,16 +124,25 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
const torch::Tensor& grouped_layout,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 64);
|
||||
const std::string& compiled_dims,
|
||||
const bool& use_psum_layout,
|
||||
const std::optional<int>& expected_m_for_psum_layout) {
|
||||
const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;
|
||||
|
||||
// NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`.
|
||||
// Otherwise, treat the contiguous layout as a whole.
|
||||
const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m;
|
||||
const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1;
|
||||
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::KernelNoSF,
|
||||
gemm_type, KernelType::KernelNoSF,
|
||||
// NOTES: `num_groups` is 1, since the contiguous layout is seen as a whole
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), false,
|
||||
m_for_config, n, k, num_groups_for_config, major_a, major_b,
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
@@ -154,14 +163,14 @@ static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
|
||||
|
||||
// Launch
|
||||
const SM100BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = m_indices.data_ptr(),
|
||||
.grouped_layout = grouped_layout.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_cd = tensor_map_cd
|
||||
@@ -179,11 +188,11 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a,
|
||||
const int& expected_m,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 64);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedMasked, KernelType::KernelNoSF,
|
||||
expected_m, n, k, num_groups, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), false,
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
@@ -204,7 +213,7 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a,
|
||||
|
||||
// Launch
|
||||
const SM100BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
@@ -243,7 +252,8 @@ static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a,
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::KGroupedContiguous, KernelType::KernelNoSF,
|
||||
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
|
||||
torch::kBFloat16, d.scalar_type(), c.has_value(),
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Create tensor descriptors
|
||||
@@ -290,7 +300,8 @@ static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Batched, KernelType::KernelNoSF,
|
||||
b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K,
|
||||
torch::kBFloat16, tensor_d.scalar_type(), false,
|
||||
tensor_a.scalar_type(), tensor_b.scalar_type(),
|
||||
tensor_d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
|
||||
@@ -337,7 +348,8 @@ static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Batched, KernelType::KernelNoSF,
|
||||
b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN,
|
||||
torch::kBFloat16, tensor_d.scalar_type(), false,
|
||||
tensor_a.scalar_type(), tensor_b.scalar_type(),
|
||||
tensor_d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
|
||||
|
||||
@@ -15,10 +15,11 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8Gemm1D1DRuntime> {
|
||||
class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8FP4Gemm1D1DRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
int gran_k_a, gran_k_b;
|
||||
const std::string& compiled_dims;
|
||||
const std::optional<std::string>& epilogue_type;
|
||||
|
||||
@@ -41,6 +42,7 @@ using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d1d_impl<
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}, {},
|
||||
{}, {}, {},
|
||||
@@ -50,12 +52,14 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {}, {},
|
||||
{}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
|
||||
args.gran_k_a, args.gran_k_b,
|
||||
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
@@ -64,7 +68,8 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation,
|
||||
to_string(args.gemm_config.a_dtype), to_string(args.gemm_config.b_dtype), to_string(args.gemm_config.cd_dtype),
|
||||
get_default_epilogue_type(args.epilogue_type));
|
||||
}
|
||||
|
||||
@@ -78,19 +83,20 @@ static void __instantiate_kernel() {{
|
||||
}
|
||||
};
|
||||
|
||||
static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const int& gran_k_a, const int& gran_k_b,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::Kernel1D1D,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const auto& cd = c.value_or(d);
|
||||
@@ -110,14 +116,16 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
config.block_m, gran_k_a, 1, 0);
|
||||
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
|
||||
config.block_n, config.block_k, 1, 0);
|
||||
config.block_n, gran_k_b, 1, 0);
|
||||
|
||||
// Launch
|
||||
const SM100FP8Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = 1,
|
||||
.gran_k_a = gran_k_a,
|
||||
.gran_k_b = gran_k_b,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
.gemm_config = config,
|
||||
@@ -131,24 +139,33 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_cd = tensor_map_cd
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code);
|
||||
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& grouped_layout,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const int& gran_k_a, const int& gran_k_b,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims,
|
||||
const bool& use_psum_layout,
|
||||
const std::optional<int>& expected_m_for_psum_layout) {
|
||||
const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;
|
||||
|
||||
// NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`.
|
||||
// Otherwise, treat the contiguous layout as a whole.
|
||||
const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m;
|
||||
const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1;
|
||||
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
|
||||
// NOTES: `num_groups` is 1, since the contiguous layout is seen as a whole
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
gemm_type, KernelType::Kernel1D1D,
|
||||
m_for_config, n, k, num_groups_for_config, major_a, major_b,
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Create tensor descriptors
|
||||
@@ -168,45 +185,48 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
config.block_m, gran_k_a, 1, 0);
|
||||
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
|
||||
config.block_n, config.block_k, num_groups, 0);
|
||||
config.block_n, gran_k_b, num_groups, 0);
|
||||
|
||||
// Launch kernel
|
||||
const SM100FP8Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = num_groups,
|
||||
.gran_k_a = gran_k_a,
|
||||
.gran_k_b = gran_k_b,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = m_indices.data_ptr(),
|
||||
.grouped_layout = grouped_layout.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_cd = tensor_map_cd
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code);
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code);
|
||||
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const int& expected_m,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const int& expected_m,
|
||||
const int& gran_k_a, const int& gran_k_b,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedMasked, KernelType::Kernel1D1D,
|
||||
expected_m, n, k, num_groups, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Create tensor descriptors
|
||||
@@ -226,14 +246,16 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
|
||||
static_cast<int>(d.stride(-2)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, num_groups, 0);
|
||||
config.block_m, gran_k_a, num_groups, 0);
|
||||
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
|
||||
config.block_n, config.block_k, num_groups, 0);
|
||||
config.block_n, gran_k_b, num_groups, 0);
|
||||
|
||||
// Launch kernel
|
||||
const SM100FP8Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = num_groups,
|
||||
.gran_k_a = gran_k_a,
|
||||
.gran_k_b = gran_k_b,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
@@ -247,19 +269,19 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_cd = tensor_map_cd
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code);
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code);
|
||||
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n,
|
||||
const std::vector<int>& ks, const torch::Tensor& ks_tensor,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n,
|
||||
const std::vector<int>& ks, const torch::Tensor& ks_tensor,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);
|
||||
|
||||
int sum_k = 0, sum_sf_k = 0;
|
||||
@@ -274,7 +296,8 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
|
||||
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Create tensor descriptors
|
||||
@@ -299,9 +322,11 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
|
||||
config.block_n, config.block_k, 1, 0);
|
||||
|
||||
// Launch kernel
|
||||
const SM100FP8Gemm1D1DRuntime::Args& args = {
|
||||
const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = sum_k,
|
||||
.num_groups = num_groups,
|
||||
.gran_k_a = 128,
|
||||
.gran_k_b = 128,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
@@ -315,9 +340,9 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_cd = tensor_map_cd
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code);
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code);
|
||||
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
@@ -330,7 +355,8 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Batched, KernelType::Kernel1D1D,
|
||||
m, n, k, batch_size, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
|
||||
@@ -364,9 +390,11 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
config.block_n, config.block_k, batch_size, 0);
|
||||
|
||||
// Launch
|
||||
const SM100FP8Gemm1D1DRuntime::Args& args = {
|
||||
const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = batch_size,
|
||||
.gran_k_a = 128,
|
||||
.gran_k_b = 128,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
@@ -380,9 +408,9 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_cd = tensor_map_cd
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
149
csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp
Normal file
149
csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp
Normal file
@@ -0,0 +1,149 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime wrapper for the SM100 TF32 HC-prenorm GEMM kernel: renders the
// kernel-instantiation source from `Args` and forwards the launch arguments.
class SM100BF16HCPrenormGemmRuntime final: public LaunchRuntime<SM100BF16HCPrenormGemmRuntime> {
public:
    // Everything required to specialize and launch one kernel instance
    struct Args {
        // Problem shape
        int m;
        int n;
        int k;

        // Tile shape
        int block_m;
        int block_n;
        int block_k;

        // Kernel configuration baked into the template arguments
        int num_splits;
        int swizzle_cd_mode;
        int num_stages;
        int num_mma_threads;
        int num_cast_and_reduce_threads;

        // Launch configuration
        LaunchArgs launch_args;

        // Arguments passed to the kernel at launch time
        CUtensorMap tensor_map_a;
        CUtensorMap tensor_map_b;
        CUtensorMap tensor_map_d;
        float* sqr_sum;
    };

    // Render the source that instantiates the kernel template.
    // NOTES: `m` is not a template argument here, it is only passed at launch time
    static std::string generate_impl(const Args& args) {
        return fmt::format(R"(
#include <deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm100_tf32_hc_prenorm_gemm_impl<
        {}, {},
        {}, {}, {},
        {},
        {},
        {},
        {}, {}
    >);
}};
)",
            // Shape (N and K only)
            args.n, args.k,
            // Tile shape
            args.block_m, args.block_n, args.block_k,
            // Split count, CD swizzle mode and pipeline depth
            args.num_splits,
            args.swizzle_cd_mode,
            args.num_stages,
            // Thread counts
            args.num_mma_threads, args.num_cast_and_reduce_threads);
    }

    // Launch the built kernel; the argument order mirrors the call below
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        // TODO: optimize `args` copy
        DG_CUDA_UNIFIED_CHECK(launch_kernel(
            kernel, config,
            args.m,
            args.tensor_map_a, args.tensor_map_b, args.tensor_map_d,
            args.sqr_sum));
    }
};
|
||||
|
||||
// Host-side launcher for the SM100 TF32 HC-prenorm GEMM kernel.
//
// Builds TMA descriptors for `a` (m x k) and `b` (n x k), both described as K-major,
// and for the output `d` (2D when `num_splits == 1`, otherwise 3D with `num_splits`
// as the outermost dimension). It then picks the deepest pipeline (at most 12 stages)
// whose shared-memory estimate fits `SM100ArchSpec::smem_capacity`, JIT-builds the
// kernel and launches it on `num_splits * ceil_div(m, block_m)` blocks.
//
// Requirements enforced below: `n <= 128`, `n % 8 == 0` and `k % 64 == 0`.
// `sqr_sum` is handed to the kernel as a raw `float*`.
static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a,
                                       const torch::Tensor& b,
                                       const torch::Tensor& d,
                                       const torch::Tensor& sqr_sum,
                                       const int& m, const int& n, const int& k,
                                       const int& num_splits) {
    constexpr int block_m = 64;
    constexpr int block_k = 64;
    constexpr int num_mma_threads = 128;
    constexpr int num_cast_and_reduce_threads = 128;

    // A single N tile covers the whole (16-aligned) N extent
    const int block_n = align(n, 16);
    DG_HOST_ASSERT(n <= block_n);
    DG_HOST_ASSERT(n <= 128 and n % 8 == 0);
    DG_HOST_ASSERT(k % block_k == 0);

    // Create tensor descriptors
    const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float));
    const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
                                               block_m, block_k,
                                               static_cast<int>(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
                                               get_swizzle_mode(block_k, a.element_size()), 0,
                                               true);
    const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k,
                                               block_n, block_k,
                                               static_cast<int>(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
                                               get_swizzle_mode(block_k, b.element_size()), 0,
                                               true);
    // With more than one split, `d` is addressed as a 3D tensor using `d.stride(-3)` between splits
    const auto& tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n,
                                                                  block_m, block_n,
                                                                  static_cast<int>(d.stride(-2)), 1,
                                                                  swizzle_cd_mode)
                                               : make_tma_3d_desc(d, n, m, num_splits,
                                                                  block_n, block_m, 1,
                                                                  static_cast<int>(d.stride(-2)),
                                                                  static_cast<int>(d.stride(-3)),
                                                                  swizzle_cd_mode);

    // Calculate stages
    // NOTES: these terms do not depend on the stage count, so compute them once.
    // A tiles are sized with BF16 elements and B tiles with FP32 elements
    const int smem_a_per_stage = block_m * block_k * static_cast<int>(sizeof(nv_bfloat16));
    const int smem_b_per_stage = block_n * block_k * static_cast<int>(sizeof(float));
    const int smem_cd = block_m * swizzle_cd_mode;
    const int smem_tmem_ptr = 4;

    // Start from 12 stages and shrink until the estimate fits into shared memory
    int num_stages = 12, smem_size = 0;
    while (num_stages > 0) {
        // 8 bytes per barrier, `4 * num_stages + 1` barriers in total
        const int smem_barriers = (num_stages * 4 + 1) * 8;
        smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages +
                    smem_cd + smem_barriers + smem_tmem_ptr;

        if (smem_size <= SM100ArchSpec::smem_capacity)
            break;
        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    // NOTES: the adjacent literals are concatenated, so the ", " after "split K" is required
    if (get_env("DG_JIT_DEBUG", 0)) {
        printf("M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split K: %d, "
               "stages: %d, shared memory: %d, swizzle CD: %d\n",
               m, n, k, block_m, block_n, block_k, num_splits,
               num_stages, smem_size, swizzle_cd_mode);
    }

    // Launch
    const SM100BF16HCPrenormGemmRuntime::Args& args = {
        .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .num_splits = num_splits,
        .swizzle_cd_mode = swizzle_cd_mode,
        .num_stages = num_stages,
        .num_mma_threads = num_mma_threads,
        .num_cast_and_reduce_threads = num_cast_and_reduce_threads,
        .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size, 1),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d,
        .sqr_sum = sqr_sum.data_ptr<float>()
    };
    const auto& code = SM100BF16HCPrenormGemmRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code);
    SM100BF16HCPrenormGemmRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -79,13 +79,11 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
|
||||
const auto& aligned_k = align(k, 64);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), c.has_value(),
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
@@ -107,7 +105,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
|
||||
|
||||
// Launch
|
||||
const SM90BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
@@ -138,7 +136,8 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), false,
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
@@ -192,7 +191,8 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedMasked, KernelType::KernelNoSF,
|
||||
expected_m, n, k, num_groups, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), false,
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
@@ -253,7 +253,8 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::KGroupedContiguous, KernelType::KernelNoSF,
|
||||
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
|
||||
torch::kBFloat16, d.scalar_type(), c.has_value(),
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Create tensor descriptors
|
||||
@@ -300,7 +301,8 @@ static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Batched, KernelType::KernelNoSF,
|
||||
b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K,
|
||||
torch::kBFloat16, tensor_d.scalar_type(), false,
|
||||
tensor_a.scalar_type(), tensor_b.scalar_type(),
|
||||
tensor_d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
|
||||
@@ -346,7 +348,8 @@ static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Batched, KernelType::KernelNoSF,
|
||||
b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN,
|
||||
torch::kBFloat16, tensor_d.scalar_type(), false,
|
||||
tensor_a.scalar_type(), tensor_b.scalar_type(),
|
||||
tensor_d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
|
||||
|
||||
@@ -88,7 +88,8 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Normal, KernelType::Kernel1D1D,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
@@ -138,7 +139,7 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
SM90FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
@@ -156,7 +157,8 @@ static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Te
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
|
||||
m, n, max_k, num_groups, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
|
||||
@@ -87,11 +87,11 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Normal, KernelType::Kernel1D2D,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
@@ -118,7 +118,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
// Launch
|
||||
const SM90FP8Gemm1D2DRuntime::Args& args = {
|
||||
.major_sfb = major_sfb,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
@@ -148,11 +148,11 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
@@ -179,7 +179,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
// Launch
|
||||
const SM90FP8Gemm1D2DRuntime::Args& args = {
|
||||
.major_sfb = major_sfb,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
@@ -207,14 +207,14 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
|
||||
const int& expected_m,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedMasked, KernelType::Kernel1D2D,
|
||||
expected_m, n, k, num_groups, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
a.scalar_type(), b.scalar_type(),
|
||||
d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
@@ -241,7 +241,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
|
||||
// Launch
|
||||
const SM90FP8Gemm1D2DRuntime::Args& args = {
|
||||
.major_sfb = major_sfb,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
@@ -261,4 +261,71 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
|
||||
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
// Host-side launcher for the batched FP8 GEMM on SM90, using the 1D2D kernel runtime.
// A, B and D are described to TMA as 3D tensors with `batch_size` as the outermost
// dimension; `batch_size` is also forwarded to the kernel as `num_groups`.
// Only BF16 outputs and K-major A/B are accepted (asserted below).
static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
                         const torch::Tensor& b, const torch::Tensor& sfb,
                         const std::optional<torch::Tensor>& c,
                         const torch::Tensor& d,
                         const int& batch_size, const int& m, const int& n, const int& k,
                         const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb,
                         const std::string& compiled_dims) {
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    const auto& config = get_best_config<SM90ArchSpec>(
        GemmType::Batched, KernelType::Kernel1D2D,
        m, n, k, batch_size, major_a, major_b,
        a.scalar_type(), b.scalar_type(),
        d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());
    const auto& smem = config.smem_config;

    // Requires no TMA splits
    DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
    DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);

    // A operand: (batch, m, k) viewed with K innermost
    const int ab_load_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
    const auto& tma_a = make_tma_3d_desc(
        a, k, m, batch_size,
        config.block_k, ab_load_m, 1,
        a.stride(1), a.stride(0),
        smem.swizzle_a_mode);

    // B operand: (batch, n, k) viewed with K innermost
    const int ab_load_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n);
    const auto& tma_b = make_tma_3d_desc(
        b, k, n, batch_size,
        config.block_k, ab_load_n, 1,
        b.stride(1), b.stride(0),
        smem.swizzle_b_mode);

    // Output: (batch, m, n) viewed with N innermost
    const int cd_store_m = SM90ArchSpec::get_cd_store_block_m(config.block_m);
    const int cd_store_n = SM90ArchSpec::get_cd_store_block_n(config.block_n);
    const auto& tma_d = make_tma_3d_desc(
        d, n, m, batch_size,
        cd_store_n, cd_store_m, 1,
        d.stride(1), d.stride(0),
        smem.swizzle_cd_mode);

    // Only SFA goes through TMA, SFB is passed to the kernel as a raw pointer
    const auto& tma_sfa = make_tma_sf_desc(
        cute::UMMA::Major::MN, sfa, m, k,
        config.block_m, config.block_k, batch_size, 0);

    // Launch
    const SM90FP8Gemm1D2DRuntime::Args args = {
        .major_sfb = major_sfb,
        .m = m, .n = n, .k = k,
        .num_groups = batch_size,
        .compiled_dims = compiled_dims,
        .epilogue_type = std::nullopt,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  smem.smem_size,
                                  config.multicast_config.num_multicast),
        .sfb = sfb.data_ptr(),
        .grouped_layout = nullptr,
        .tensor_map_a = tma_a,
        .tensor_map_b = tma_b,
        .tensor_map_d = tma_d,
        .tensor_map_sfa = tma_sfa,
    };
    const auto& kernel_source = SM90FP8Gemm1D2DRuntime::generate(args);
    const auto& built_runtime = compiler->build("sm90_fp8_gemm_1d2d", kernel_source);
    SM90FP8Gemm1D2DRuntime::launch(built_runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
152
csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp
Normal file
152
csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp
Normal file
@@ -0,0 +1,152 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime wrapper for the SM90 TF32 HC-prenorm GEMM kernel: renders the
// kernel-instantiation source from `Args` and forwards the launch arguments.
class SM90BF16HCPrenormGemmRuntime final: public LaunchRuntime<SM90BF16HCPrenormGemmRuntime> {
public:
    // Everything required to specialize and launch one kernel instance
    struct Args {
        // Problem shape
        int m;
        int n;
        int k;

        // Tile shape
        int block_m;
        int block_n;
        int block_k;

        // Kernel configuration baked into the template arguments
        int num_splits;
        int swizzle_cd_mode;
        int num_stages;
        int num_math_threads;
        int num_tma_threads;

        // Launch configuration
        LaunchArgs launch_args;

        // Arguments passed to the kernel at launch time
        CUtensorMap tensor_map_a;
        CUtensorMap tensor_map_b;
        CUtensorMap tensor_map_d;
        float* sqr_sum;
    };

    // Render the source that instantiates the kernel template.
    // NOTES: `m` is not a template argument here, it is only passed at launch time
    static std::string generate_impl(const Args& args) {
        return fmt::format(R"(
#include <deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm90_tf32_hc_prenorm_gemm_impl<
        {}, {},
        {}, {}, {},
        {},
        {},
        {},
        {}, {}
    >);
}};
)",
            // Shape (N and K only)
            args.n, args.k,
            // Tile shape
            args.block_m, args.block_n, args.block_k,
            // Split count, CD swizzle mode and pipeline depth
            args.num_splits,
            args.swizzle_cd_mode,
            args.num_stages,
            // Thread counts
            args.num_math_threads, args.num_tma_threads);
    }

    // Launch the built kernel; the argument order mirrors the call below
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        // TODO: optimize `args` copy
        DG_CUDA_UNIFIED_CHECK(launch_kernel(
            kernel, config,
            args.m,
            args.tensor_map_a, args.tensor_map_b, args.tensor_map_d,
            args.sqr_sum));
    }
};
|
||||
|
||||
// Host-side launcher for the SM90 TF32 HC-prenorm GEMM kernel.
//
// Builds TMA descriptors for `a` (m x k) and `b` (n x k), both described as K-major,
// and for the output `d` (2D when `num_splits == 1`, otherwise 3D with `num_splits`
// as the outermost dimension). It then picks the deepest pipeline (at most 12 stages)
// whose shared-memory estimate fits `SM90ArchSpec::smem_capacity`, JIT-builds the
// kernel and launches it on `num_splits * ceil_div(m, block_m)` blocks.
//
// Requirements enforced below: `n <= 32`, `n % 8 == 0` and `k % 64 == 0`.
// `sqr_sum` is handed to the kernel as a raw `float*`.
static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a,
                                      const torch::Tensor& b,
                                      const torch::Tensor& d,
                                      const torch::Tensor& sqr_sum,
                                      const int& m, const int& n, const int& k,
                                      const int& num_splits) {
    constexpr int block_m = 64;
    constexpr int block_k = 64;
    constexpr int num_math_threads = 128;
    constexpr int num_tma_threads = 128;
    constexpr int num_threads = num_math_threads + num_tma_threads;

    // A single N tile covers the whole (16-aligned) N extent
    const int block_n = align(n, 16);
    DG_HOST_ASSERT(n <= block_n);
    // Only support small N for now
    DG_HOST_ASSERT(n <= 32 and n % 8 == 0);
    DG_HOST_ASSERT(k % block_k == 0);

    // Create tensor descriptors
    const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float));
    const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
                                               block_m, block_k,
                                               static_cast<int>(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
                                               get_swizzle_mode(block_k, a.element_size()), 0,
                                               true);
    const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k,
                                               block_n, block_k,
                                               static_cast<int>(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
                                               get_swizzle_mode(block_k, b.element_size()), 0,
                                               true);
    // With more than one split, `d` is addressed as a 3D tensor using `d.stride(-3)` between splits
    const auto& tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n,
                                                                  block_m, block_n,
                                                                  static_cast<int>(d.stride(-2)), 1,
                                                                  swizzle_cd_mode)
                                               : make_tma_3d_desc(d, n, m, num_splits,
                                                                  block_n, block_m, 1,
                                                                  static_cast<int>(d.stride(-2)),
                                                                  static_cast<int>(d.stride(-3)),
                                                                  swizzle_cd_mode);

    // Calculate stages
    // NOTES: these terms do not depend on the stage count, so compute them once.
    // A tiles are sized with BF16 elements and B tiles with FP32 elements
    const int smem_a_per_stage = block_m * block_k * static_cast<int>(sizeof(nv_bfloat16));
    const int smem_b_per_stage = block_n * block_k * static_cast<int>(sizeof(float));
    const int smem_cd = block_m * swizzle_cd_mode;

    // Start from 12 stages and shrink until the estimate fits into shared memory
    int num_stages = 12, smem_size = 0;
    while (num_stages > 0) {
        // 8 bytes per barrier, two barriers per stage
        const int smem_barriers = num_stages * 2 * 8;
        smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages +
                    smem_cd + smem_barriers;

        if (smem_size <= SM90ArchSpec::smem_capacity)
            break;
        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    // NOTES: the adjacent literals are concatenated, so the ", " after "split K" is required
    if (get_env("DG_JIT_DEBUG", 0)) {
        printf("M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split K: %d, "
               "stages: %d, shared memory: %d, swizzle CD: %d\n",
               m, n, k, block_m, block_n, block_k, num_splits,
               num_stages, smem_size, swizzle_cd_mode);
    }

    // The launch requests the full shared-memory capacity rather than the estimate printed above
    // NOTE(review): presumably the kernel needs more shared memory than the estimate covers — confirm
    smem_size = SM90ArchSpec::smem_capacity;

    // Launch
    const SM90BF16HCPrenormGemmRuntime::Args& args = {
        .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .num_splits = num_splits,
        .swizzle_cd_mode = swizzle_cd_mode,
        .num_stages = num_stages,
        .num_math_threads = num_math_threads,
        .num_tma_threads = num_tma_threads,
        .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size, 1),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d,
        .sqr_sum = sqr_sum.data_ptr<float>()
    };
    const auto& code = SM90BF16HCPrenormGemmRuntime::generate(args);
    const auto& runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code);
    SM90BF16HCPrenormGemmRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -37,7 +37,6 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a,
|
||||
const bool& accumulate) {
|
||||
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
cudaDataType_t scale_type = CUDA_R_32F;
|
||||
const int& math_sms = device_runtime->get_num_sms();
|
||||
|
||||
// Operation description
|
||||
cublasLtMatmulDesc_t desc;
|
||||
@@ -45,9 +44,13 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a,
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
|
||||
|
||||
#if DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE
|
||||
const int& math_sms = device_runtime->get_num_sms();
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms)));
|
||||
|
||||
#if DG_FP8_COMPATIBLE
|
||||
#endif
|
||||
|
||||
#if DG_FP8_COMPATIBLE and DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE
|
||||
bool fp8_fast_accumulate = false;
|
||||
if (a.scalar_type() == torch::kFloat8_e4m3fn)
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate)));
|
||||
|
||||
@@ -174,13 +174,13 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
const int& logits_stride,
|
||||
const int& block_table_stride,
|
||||
const int& num_sms,
|
||||
const int& num_math_warp_groups) {
|
||||
const int& split_kv) {
|
||||
const int num_specialized_threads = 128;
|
||||
const int mma_m = (device_runtime->get_arch_major() == 10 ? 128 : 64);
|
||||
const int num_math_warp_groups = split_kv / mma_m;
|
||||
const int num_math_threads = num_math_warp_groups * 128;
|
||||
const int num_extra_threads = device_runtime->get_arch_major() == 10 ? 128 : 0;
|
||||
const int num_q_stages = 3, num_kv_stages = 3;
|
||||
const int split_kv = num_math_warp_groups * block_kv;
|
||||
DG_HOST_ASSERT(logits_stride % (num_math_warp_groups * block_kv) == 0);
|
||||
const int num_q_stages = 3, num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3);
|
||||
DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0);
|
||||
|
||||
// Construct TMAs
|
||||
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
|
||||
@@ -196,23 +196,39 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
next_n * num_heads, 1, next_n * num_heads, 0);
|
||||
|
||||
// Calculate shared memory size
|
||||
const int swizzle_alignment = head_dim * 8;
|
||||
int smem_size = 0;
|
||||
if (device_runtime->get_arch_major() == 9) {
|
||||
const int swizzle_alignment = head_dim * 8;
|
||||
|
||||
const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
|
||||
const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
|
||||
const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);
|
||||
const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
|
||||
const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
|
||||
const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);
|
||||
|
||||
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
|
||||
const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast<int>(kv_cache_scales.element_size()), swizzle_alignment);
|
||||
const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment);
|
||||
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
|
||||
const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast<int>(kv_cache_scales.element_size()), swizzle_alignment);
|
||||
const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment);
|
||||
|
||||
// Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90
|
||||
const int smem_umma_barriers = num_math_warp_groups * 2 * 8;
|
||||
const int smem_tmem_ptr = 4;
|
||||
// Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90
|
||||
const int smem_umma_barriers = num_math_warp_groups * 2 * 8;
|
||||
const int smem_tmem_ptr = 4;
|
||||
|
||||
const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr;
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
|
||||
smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr;
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
} else {
|
||||
const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
|
||||
const int smem_kv_size_per_stage = split_kv * head_dim * static_cast<int>(kv_cache.element_size());
|
||||
const int smem_kv_scale_size_per_stage = split_kv * static_cast<int>(kv_cache_scales.element_size());
|
||||
const int smem_weight_size_per_stage = next_n * num_heads * static_cast<int>(weights.element_size());
|
||||
|
||||
const int smem_barriers = (num_q_stages + num_kv_stages) * 2 * 8;
|
||||
const int smem_umma_barriers = num_math_warp_groups * 2 * 8;
|
||||
const int smem_tmem_ptr = 4;
|
||||
|
||||
smem_size = num_q_stages * (smem_q_size_per_stage + smem_weight_size_per_stage) +
|
||||
num_kv_stages * (smem_kv_size_per_stage + smem_kv_scale_size_per_stage) +
|
||||
smem_barriers + smem_umma_barriers + smem_tmem_ptr;
|
||||
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
|
||||
}
|
||||
|
||||
// Launch
|
||||
const SMXXFP8PagedMQALogitsRuntime::Args& args = {
|
||||
@@ -238,7 +254,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
.num_specialized_threads = num_specialized_threads,
|
||||
.num_math_threads = num_math_threads,
|
||||
.launch_args = LaunchArgs(num_sms,
|
||||
num_specialized_threads + num_math_threads + num_extra_threads,
|
||||
num_specialized_threads + num_math_threads,
|
||||
smem_size)
|
||||
};
|
||||
const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "apis/attention.hpp"
|
||||
#include "apis/einsum.hpp"
|
||||
#include "apis/hyperconnection.hpp"
|
||||
#include "apis/gemm.hpp"
|
||||
#include "apis/layout.hpp"
|
||||
#include "apis/runtime.hpp"
|
||||
@@ -15,8 +16,10 @@
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "DeepGEMM C++ library";
|
||||
|
||||
// TODO: make SM80 incompatible issues raise errors
|
||||
deep_gemm::attention::register_apis(m);
|
||||
deep_gemm::einsum::register_apis(m);
|
||||
deep_gemm::hyperconnection::register_apis(m);
|
||||
deep_gemm::gemm::register_apis(m);
|
||||
deep_gemm::layout::register_apis(m);
|
||||
deep_gemm::runtime::register_apis(m);
|
||||
|
||||
@@ -2,9 +2,16 @@
|
||||
|
||||
#include <torch/version.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// `torch::kFloat8_e4m3fn` is supported since PyTorch 2.1
|
||||
#define DG_FP8_COMPATIBLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 1))
|
||||
|
||||
// `cuTensorMapEncodeTiled` is supported since CUDA Driver API 12.1
|
||||
#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010)
|
||||
#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010)
|
||||
|
||||
// `cublasGetStatusString` is supported since CUDA Runtime API 11.4.2
|
||||
#define DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE (CUDART_VERSION >= 11042)
|
||||
|
||||
// `CUBLASLT_MATMUL_DESC_FAST_ACCUM` and `CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET` are supported since CUDA Runtime API 11.8
|
||||
#define DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE (CUDART_VERSION >= 11080)
|
||||
@@ -5,6 +5,8 @@
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "compatibility.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class DGException final : public std::exception {
|
||||
@@ -74,6 +76,25 @@ do { \
|
||||
#endif
|
||||
|
||||
#ifndef DG_CUBLASLT_CHECK
|
||||
|
||||
#if !DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE
|
||||
// Fallback status-to-text conversion, compiled only when
// `DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE` is false.
// Returns the enumerator name as a static string, or a generic message for
// any status value not listed below.
inline const char* cublasGetStatusString(cublasStatus_t status) {
    const char* text = "Unknown cuBLAS error";
    switch (status) {
        case CUBLAS_STATUS_SUCCESS:          text = "CUBLAS_STATUS_SUCCESS";          break;
        case CUBLAS_STATUS_NOT_INITIALIZED:  text = "CUBLAS_STATUS_NOT_INITIALIZED";  break;
        case CUBLAS_STATUS_ALLOC_FAILED:     text = "CUBLAS_STATUS_ALLOC_FAILED";     break;
        case CUBLAS_STATUS_INVALID_VALUE:    text = "CUBLAS_STATUS_INVALID_VALUE";    break;
        case CUBLAS_STATUS_ARCH_MISMATCH:    text = "CUBLAS_STATUS_ARCH_MISMATCH";    break;
        case CUBLAS_STATUS_MAPPING_ERROR:    text = "CUBLAS_STATUS_MAPPING_ERROR";    break;
        case CUBLAS_STATUS_EXECUTION_FAILED: text = "CUBLAS_STATUS_EXECUTION_FAILED"; break;
        case CUBLAS_STATUS_INTERNAL_ERROR:   text = "CUBLAS_STATUS_INTERNAL_ERROR";   break;
        case CUBLAS_STATUS_NOT_SUPPORTED:    text = "CUBLAS_STATUS_NOT_SUPPORTED";    break;
        case CUBLAS_STATUS_LICENSE_ERROR:    text = "CUBLAS_STATUS_LICENSE_ERROR";    break;
        default:                                                                      break;
    }
    return text;
}
|
||||
#endif
|
||||
|
||||
#define DG_CUBLASLT_CHECK(cmd) \
|
||||
do { \
|
||||
const auto& e = (cmd); \
|
||||
|
||||
@@ -36,15 +36,34 @@ static bool fp8_requires_k_major() {
|
||||
// Tensor utils
|
||||
template <int N>
|
||||
static auto get_shape(const torch::Tensor& t) {
|
||||
DG_HOST_ASSERT(t.dim() == N);
|
||||
return [&t] <size_t... Is> (std::index_sequence<Is...>) {
|
||||
return std::make_tuple(static_cast<int>(t.sizes()[Is])...);
|
||||
}(std::make_index_sequence<N>());
|
||||
}
|
||||
|
||||
// Validates the dtype of a 2D operand and returns its `(mn, k)` shape.
// FP8 (`torch::kFloat8_e4m3fn`) tensors are returned with their stored shape.
// Any other dtype must be `kPackedFP4` on arch major 10; in that case the
// extent along the major dimension is doubled (presumably because each stored
// element packs two FP4 values -- TODO confirm).
static std::tuple<int, int> check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) {
    auto [shape_mn, shape_k] = get_shape<2>(ab);

    const bool is_fp8 = (ab.scalar_type() == torch::kFloat8_e4m3fn);
    if (not is_fp8) {
        DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10);
        if (major == cute::UMMA::Major::K) {
            shape_k *= 2;
        } else {
            shape_mn *= 2;
        }
    }
    return std::make_tuple(shape_mn, shape_k);
}
|
||||
|
||||
// Grouped (3D) counterpart of the FP8/FP4 operand check: returns
// `(num_groups, mn, k)`.
// FP8 (`torch::kFloat8_e4m3fn`) tensors are returned with their stored shape.
// Any other dtype must be `kPackedFP4` on arch major 10; in that case the
// extent along the major dimension is doubled (presumably because each stored
// element packs two FP4 values -- TODO confirm).
static std::tuple<int, int, int> check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) {
    auto [group_count, shape_mn, shape_k] = get_shape<3>(ab);

    const bool is_fp8 = (ab.scalar_type() == torch::kFloat8_e4m3fn);
    if (not is_fp8) {
        DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10);
        if (major == cute::UMMA::Major::K) {
            shape_k *= 2;
        } else {
            shape_mn *= 2;
        }
    }
    return std::make_tuple(group_count, shape_mn, shape_k);
}
|
||||
|
||||
// Recipe
|
||||
static std::tuple<int, int, int>
|
||||
get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) {
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
const auto arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat);
|
||||
return {1, 128, 128};
|
||||
@@ -70,7 +89,7 @@ static torch::Tensor check_sf_layout(const torch::Tensor& sf,
|
||||
DG_HOST_ASSERT(sf.scalar_type() == type_check.value());
|
||||
|
||||
// Always do shape checks
|
||||
const auto& sf_dtype = sf.scalar_type();
|
||||
const auto sf_dtype = sf.scalar_type();
|
||||
DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt);
|
||||
DG_HOST_ASSERT(sf.dim() == static_cast<int>(num_groups.has_value()) + 2);
|
||||
if (num_groups.has_value())
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// TODO: Use `torch::kFloat4_e2m1fn_x2`
|
||||
constexpr auto kPackedFP4 = torch::kUInt8;
|
||||
|
||||
template <typename T>
|
||||
static T ceil_div(const T& a, const T& b) {
|
||||
return (a + b - 1) / b;
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
import torch
|
||||
from torch.version import cuda as cuda_version
|
||||
from packaging import version
|
||||
|
||||
# Set some default environment provided at setup
|
||||
try:
|
||||
@@ -29,9 +27,15 @@ from ._C import (
|
||||
cublaslt_gemm_tn, cublaslt_gemm_tt,
|
||||
)
|
||||
|
||||
if version.parse(cuda_version) >= version.parse('12.1'):
|
||||
try:
|
||||
# DeepGEMM Kernels
|
||||
from ._C import (
|
||||
# FP8 FP4 GEMMs
|
||||
fp8_fp4_gemm_nt, fp8_fp4_gemm_nn,
|
||||
fp8_fp4_gemm_tn, fp8_fp4_gemm_tt,
|
||||
m_grouped_fp8_fp4_gemm_nt_contiguous,
|
||||
m_grouped_fp8_fp4_gemm_nn_contiguous,
|
||||
m_grouped_fp8_fp4_gemm_nt_masked,
|
||||
# FP8 GEMMs
|
||||
fp8_gemm_nt, fp8_gemm_nn,
|
||||
fp8_gemm_tn, fp8_gemm_tt,
|
||||
@@ -55,6 +59,8 @@ if version.parse(cuda_version) >= version.parse('12.1'):
|
||||
fp8_mqa_logits,
|
||||
get_paged_mqa_logits_metadata,
|
||||
fp8_paged_mqa_logits,
|
||||
# Hyperconnection kernels
|
||||
tf32_hc_prenorm_gemm,
|
||||
# Layout kernels
|
||||
transform_sf_into_required_layout,
|
||||
get_mk_alignment_for_contiguous_layout
|
||||
@@ -64,6 +70,9 @@ if version.parse(cuda_version) >= version.parse('12.1'):
|
||||
# TODO: remove these later
|
||||
fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked
|
||||
bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked
|
||||
except ImportError:
|
||||
# Expected behavior for CUDA runtime version before 12.1
|
||||
pass
|
||||
|
||||
# Some utils
|
||||
from . import testing
|
||||
@@ -71,7 +80,10 @@ from . import utils
|
||||
from .utils import *
|
||||
|
||||
# Legacy Triton kernels for A100
|
||||
from . import legacy
|
||||
try:
|
||||
from . import legacy
|
||||
except Exception as e:
|
||||
print(f'Failed to load legacy DeepGEMM A100 Triton kernels: {e}')
|
||||
|
||||
# Initialize CPP modules
|
||||
def _find_cuda_home() -> str:
|
||||
@@ -97,4 +109,4 @@ _C.init(
|
||||
_find_cuda_home() # CUDA home
|
||||
)
|
||||
|
||||
__version__ = '2.2.0'
|
||||
__version__ = '2.3.0'
|
||||
|
||||
@@ -51,6 +51,8 @@ struct Scheduler {
|
||||
uint32_t current_group_idx = 0;
|
||||
// Only used for masked layout
|
||||
uint32_t current_m_cumsum = 0;
|
||||
// Only used for contiguous psum layout
|
||||
uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0;
|
||||
// Only used for k-grouped layout
|
||||
uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
|
||||
uint32_t next_group_idx, next_shape_k;
|
||||
@@ -72,12 +74,16 @@ struct Scheduler {
|
||||
current_shape_k = shape_k;
|
||||
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
|
||||
num_blocks = num_m_blocks * num_n_blocks;
|
||||
} else if (kGemmType == GemmType::MGroupedContiguous) {
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
|
||||
num_blocks = num_m_blocks * num_n_blocks;
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::MGroupedMasked) {
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::KGroupedContiguous) {
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
|
||||
this->grouped_layout = grouped_layout;
|
||||
current_psum_m = __ldg(grouped_layout);
|
||||
num_m_blocks = ceil_div(current_psum_m, BLOCK_M);
|
||||
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
|
||||
this->grouped_layout = grouped_layout;
|
||||
get_next_k_group(current_group_idx, current_shape_k);
|
||||
next_group_idx = current_group_idx + 1;
|
||||
@@ -131,7 +137,7 @@ struct Scheduler {
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
|
||||
const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
|
||||
const auto offset = kWithGroupOffset ? current_group_idx : 0;
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
|
||||
@@ -172,6 +178,28 @@ struct Scheduler {
|
||||
}
|
||||
|
||||
get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
|
||||
while (true) {
|
||||
// Within current group
|
||||
if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks)
|
||||
break;
|
||||
|
||||
// Move to check the next group
|
||||
if (++ current_group_idx == kNumGroups)
|
||||
return false;
|
||||
|
||||
// NOTES: `num_m_blocks` varies with the increase of the group index
|
||||
last_psum_m = align(current_psum_m, 128u);
|
||||
current_psum_m = __ldg(grouped_layout + current_group_idx);
|
||||
current_m_block_cumsum += num_m_blocks;
|
||||
num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M);
|
||||
}
|
||||
|
||||
get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx);
|
||||
|
||||
// NOTES: `last_psum_m` is aligned with 128
|
||||
m_block_idx += last_psum_m / BLOCK_M;
|
||||
DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M");
|
||||
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
|
||||
while (true) {
|
||||
// End of the task
|
||||
@@ -248,6 +276,9 @@ struct Scheduler {
|
||||
return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
||||
return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx);
|
||||
} else {
|
||||
// Unreachable
|
||||
DG_TRAP_ONLY_DEVICE_ASSERT(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -97,7 +97,8 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
|
||||
const auto& layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
|
||||
const auto& num_non_contiguous = 128 / get_atom_base(layout_type);
|
||||
if constexpr (kMajorMode == cute::UMMA::Major::K) {
|
||||
// NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
|
||||
// NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)`
|
||||
// also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis
|
||||
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
|
||||
|
||||
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
|
||||
@@ -131,8 +132,8 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
|
||||
}
|
||||
|
||||
__device__ __forceinline__
|
||||
uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sf_id) {
|
||||
desc.a_sf_id_ = sf_id, desc.b_sf_id_ = sf_id;
|
||||
uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) {
|
||||
desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id;
|
||||
return static_cast<uint64_t>(static_cast<uint32_t>(desc)) << 32;
|
||||
}
|
||||
|
||||
@@ -154,6 +155,20 @@ __device__ __forceinline__ void tcgen05_after_thread_sync() {
|
||||
asm volatile("tcgen05.fence::after_thread_sync;");
|
||||
}
|
||||
|
||||
// Issues one `cp.async.bulk.tensor.2d ... tile::gather4` PTX instruction.
//   desc_ptr   - passed as the 64-bit tensor-map operand of the instruction
//   mbarrier   - barrier whose shared-memory address is given to the
//                `mbarrier::complete_tx::bytes` completion mechanism
//   smem_ptr   - destination; converted to a 32-bit shared-memory address
//   col_idx    - first coordinate of the gather
//   row_idxs   - the four remaining coordinates (x, y, z, w), one per gathered row
//   cache_hint - 64-bit operand for the `L2::cache_hint` qualifier
// NOTE(review): the instruction is hard-coded to `cta_group::1`; confirm this
// matches every call site.
__device__ __forceinline__
|
||||
void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) {
|
||||
// Convert both generic pointers into shared-memory addresses for the PTX operands
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier);
|
||||
asm volatile(
|
||||
// Operands: %0 = destination, %1 = tensor map, %2 = column, %3..%6 = rows, %7 = mbarrier, %8 = cache hint
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
|
||||
:
|
||||
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
|
||||
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
|
||||
"r"(mbarrier_addr), "l"(cache_hint)
|
||||
// The "memory" clobber keeps the compiler from reordering memory accesses across the copy
: "memory"
|
||||
);
|
||||
}
|
||||
|
||||
// UMMA versions with relaxed assertions
|
||||
struct SM100_MMA_F16BF16_SS {
|
||||
__device__ static void
|
||||
@@ -231,4 +246,21 @@ struct SM100_MMA_MXF8F6F4_2x1SM_SS {
|
||||
}
|
||||
};
|
||||
|
||||
// Wrapper around the `tcgen05.mma.ws.cta_group::1.kind::f16` PTX instruction.
// Judging by the operand names, both A and B are described by 64-bit
// descriptors and the accumulator lives at a tensor-memory address -- TODO
// confirm against the PTX ISA.
struct SM100_MMA_F16BF16_WS_SS {
|
||||
// Issues one MMA.
//   desc_a, desc_b - 64-bit descriptors for the A and B operands
//   tmem_c         - 32-bit address operand of the accumulator
//   scale_c        - non-zero sets the predicate `p` passed to the instruction
//   desc           - only the upper 32 bits are forwarded as the instruction descriptor
__device__ static void
|
||||
fma(uint64_t const& desc_a,
|
||||
uint64_t const& desc_b,
|
||||
uint32_t const& tmem_c,
|
||||
uint32_t const& scale_c,
|
||||
uint64_t const& desc) {
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
// p = (scale_c != 0)
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
|
||||
"}\n"
|
||||
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace `deep_gemm::sm100`
|
||||
|
||||
@@ -152,6 +152,51 @@ struct BF16MMASelector {
|
||||
using type = decltype(select_type());
|
||||
};
|
||||
|
||||
template <int N_, typename MMA>
|
||||
struct TF32MMARS {
|
||||
|
||||
template <size_t ...Idx>
|
||||
__forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
|
||||
using namespace cute::SM90::GMMA;
|
||||
MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
call_fma_impl(reinterpret_cast<uint32_t*>(a), desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = N_;
|
||||
static constexpr int K = 8;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
// Compile-time selector that maps an N extent to the matching CUTE SM90 GMMA
// TF32 atom and wraps it in `TF32MMARS`; the result is exposed as `type`.
//   N      - must be one of 8, 16, 32, 64, 128, 256 (statically asserted)
//   kUseRS - only `true` is supported; `false` hits a static assertion
template <int N, bool kUseRS = true>
|
||||
struct TF32MMASelector {
|
||||
|
||||
// Returns a value of the `MMA_64x{N}x8_F32TF32TF32_RS_TN` atom type for this N.
// Each `if constexpr` is independent, so only the branch matching N contributes
// a return statement to the deduced return type.
static constexpr auto select_mma() {
|
||||
using namespace cute::SM90::GMMA;
|
||||
if constexpr (kUseRS) {
|
||||
if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN();
|
||||
if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN();
|
||||
if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN();
|
||||
if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN();
|
||||
if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN();
|
||||
if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN();
|
||||
// Reject any N without a matching atom above
DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N");
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a value of the wrapper type built around the selected atom;
// used only through `decltype` below.
static constexpr auto select_type() {
|
||||
if constexpr (kUseRS) {
|
||||
return TF32MMARS<N, decltype(select_mma())>();
|
||||
} else {
|
||||
DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now");
|
||||
}
|
||||
}
|
||||
|
||||
// The selected wrapper type
using type = decltype(select_type());
|
||||
};
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x2_STSM_N {
|
||||
|
||||
@@ -2,14 +2,36 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
enum class GemmType {
|
||||
Normal = 0,
|
||||
MGroupedContiguous = 1,
|
||||
MGroupedMasked = 2,
|
||||
KGroupedContiguous = 3,
|
||||
Batched = 4
|
||||
// Operand data-type family used by an MMA; the explicit numeric values are
// part of the definition.
// NOTE(review): presumably these values cross the host/JIT boundary and must
// stay stable -- confirm before renumbering.
enum class MmaKind {
|
||||
// BF16 operands
BF16 = 0,
|
||||
// Judging by the name, MX-scaled FP8/FP4 operands -- TODO confirm
MXFP8FP4 = 1,
|
||||
};
|
||||
|
||||
// Returns the per-element size associated with an MMA kind (presumably in
// bytes -- TODO confirm): 2 for `BF16`, 1 for `MXFP8FP4`, and 0 for any
// other value.
constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) {
    return mma_kind == MmaKind::BF16     ? 2
         : mma_kind == MmaKind::MXFP8FP4 ? 1
                                         : 0;
}
|
||||
|
||||
// Kinds of GEMM problem layout that kernels and the scheduler are
// specialized on; the explicit numeric values are part of the definition.
// NOTE(review): presumably these values cross the host/JIT boundary and must
// stay stable -- confirm before renumbering.
enum class GemmType {
|
||||
Normal = 0,
|
||||
MGroupedContiguous = 1,
|
||||
MGroupedMasked = 2,
|
||||
KGroupedContiguous = 3,
|
||||
Batched = 4,
|
||||
// Judging by the name, an M-grouped contiguous layout described by per-group prefix sums -- TODO confirm
MGroupedContiguousWithPsumLayout = 5,
|
||||
};
|
||||
|
||||
// True for both M-grouped contiguous variants (`MGroupedContiguous` and
// `MGroupedContiguousWithPsumLayout`), false for every other GEMM type.
constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) {
    return gemm_type == GemmType::MGroupedContiguous or
           gemm_type == GemmType::MGroupedContiguousWithPsumLayout;
}
|
||||
|
||||
enum class KernelType {
|
||||
Kernel1D1D = 0,
|
||||
Kernel1D2D = 1,
|
||||
|
||||
@@ -148,6 +148,10 @@ __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t
|
||||
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w));
|
||||
}
|
||||
|
||||
// 128-bit overload: stores `val` to shared memory with a single `st.shared.b128`.
// `ptr` is a generic pointer; it is converted to a shared-memory address with
// `__cvta_generic_to_shared` before being handed to the instruction.
__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) {
|
||||
// "q" binds `val` as a 128-bit operand
asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
|
||||
}
|
||||
|
||||
template <typename old_t>
|
||||
__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
|
||||
auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});
|
||||
|
||||
@@ -388,7 +388,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
|
||||
|
||||
// The pipeline stage
|
||||
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
|
||||
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
|
||||
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
||||
|
||||
// Store into shared memory
|
||||
|
||||
@@ -14,6 +14,7 @@ namespace deep_gemm {
|
||||
using namespace deep_gemm::sm100;
|
||||
|
||||
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t kGranKA, uint32_t kGranKB,
|
||||
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups,
|
||||
@@ -22,7 +23,8 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
|
||||
GemmType kGemmType, bool kWithAccumulation,
|
||||
typename a_dtype_t, typename b_dtype_t, typename cd_dtype_t,
|
||||
typename epilogue_type_t>
|
||||
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
||||
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
@@ -45,16 +47,21 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
constexpr uint32_t WAVE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
||||
constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M;
|
||||
constexpr uint32_t kNumTMAStoreStages = 2;
|
||||
constexpr uint32_t kNumSFStagesPerLoad = sizeof(uint32_t) / sizeof(cutlass::float_ue8m0_t);
|
||||
constexpr uint32_t kNumUTCCPAlignedElems = 128;
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
|
||||
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
|
||||
|
||||
constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4;
|
||||
constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4;
|
||||
DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A");
|
||||
DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B");
|
||||
|
||||
// Overwrite shape constants if the compiler gives
|
||||
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
||||
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
||||
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
||||
const uint32_t shape_sf_k = ceil_div(shape_k, BLOCK_K * kNumSFStagesPerLoad);
|
||||
const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4);
|
||||
const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4);
|
||||
|
||||
// Utils
|
||||
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
||||
@@ -78,8 +85,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
// Share memory sizes
|
||||
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
|
||||
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
|
||||
constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
|
||||
constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
|
||||
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
|
||||
@@ -89,7 +96,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
|
||||
|
||||
// NOTES: Make sure we have enough shared memory for UMMA padding
|
||||
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
|
||||
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
|
||||
|
||||
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
|
||||
@@ -118,10 +125,10 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
return reinterpret_cast<a_dtype_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
return reinterpret_cast<b_dtype_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
});
|
||||
|
||||
// SFA/SFB shared memory
|
||||
@@ -225,28 +232,31 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
|
||||
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
|
||||
if constexpr (kMajorA == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t, kIsBatchedMM>(
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
|
||||
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx);
|
||||
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::float_e4m3_t, kIsBatchedMM>(
|
||||
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
|
||||
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t, kIsBatchedMM>(
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
|
||||
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::float_e4m3_t, kIsBatchedMM>(
|
||||
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
|
||||
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx);
|
||||
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v<a_dtype_t, cutlass::float_e4m3_t> ? 1 : 2) +
|
||||
SMEM_B_SIZE_PER_STAGE / (std::is_same_v<b_dtype_t, cutlass::float_e4m3_t> ? 1 : 2);
|
||||
|
||||
// Issue SFA and SFB TMAs at certain stages
|
||||
// No swizzling, so one TMA for one SF is enough
|
||||
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0) {
|
||||
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
|
||||
tma_copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M,
|
||||
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
|
||||
scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)));
|
||||
num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
|
||||
}
|
||||
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
|
||||
tma_copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N,
|
||||
scheduler.template get_global_idx<true, IndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
|
||||
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
|
||||
scheduler.template get_global_idx<true, IndexType::SF_K>(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx));
|
||||
num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
|
||||
}
|
||||
|
||||
// Arrive at full barriers
|
||||
@@ -260,9 +270,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
// TODO: refactor `UMMA_M` calculation
|
||||
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
|
||||
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
||||
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e4m3_t, cutlass::float_e4m3_t,
|
||||
float, cutlass::float_ue8m0_t,
|
||||
constexpr uint32_t UMMA_K = 32;
|
||||
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
|
||||
UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
||||
auto sf_desc = make_sf_desc(nullptr);
|
||||
|
||||
@@ -313,19 +322,20 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
|
||||
// Do SF copy at certain stages
|
||||
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
|
||||
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
||||
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
|
||||
|
||||
// SFA and SFB copy
|
||||
// TODO: process shared memory descriptor by addition
|
||||
// TODO: process shared memory descriptor by addition
|
||||
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
|
||||
const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad;
|
||||
if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
||||
}
|
||||
}
|
||||
const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad;
|
||||
if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
@@ -337,17 +347,20 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_MXF8F6F4_SS, SM100_MMA_MXF8F6F4_2x1SM_SS>;
|
||||
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
|
||||
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
|
||||
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
|
||||
if (cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx);
|
||||
const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx);
|
||||
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id);
|
||||
|
||||
b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset");
|
||||
a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K);
|
||||
a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K);
|
||||
mma_t::fma(a_desc, b_desc,
|
||||
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
||||
k_block_idx > 0 or k > 0,
|
||||
@@ -391,11 +404,14 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
full_barriers[stage_idx]->wait(phase);
|
||||
|
||||
// Transpose for UTCCP at certain stages
|
||||
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0) {
|
||||
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
|
||||
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
|
||||
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
}
|
||||
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
|
||||
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
|
||||
@@ -454,7 +470,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
|
||||
|
||||
// The pipeline stage
|
||||
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
|
||||
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
|
||||
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
|
||||
|
||||
// Store into shared memory
|
||||
|
||||
@@ -143,7 +143,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {block_q_idx + gridDim.x, q_iter_idx + 1};
|
||||
};
|
||||
uint32_t seq_k_start[BLOCK_Q];
|
||||
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
|
||||
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
|
||||
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
||||
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
||||
@@ -152,8 +152,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
|
||||
seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
|
||||
seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
|
||||
start = min(start, min(seq_k_start[i], seq_len_kv));
|
||||
end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv));
|
||||
end = max(end, min(seq_k_end[i], seq_len_kv));
|
||||
}
|
||||
start = start / 4 * 4;
|
||||
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
|
||||
@@ -278,9 +279,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
const auto& v_offset = lane_idx;
|
||||
|
||||
// Preload weights
|
||||
constexpr uint32_t kNumWeightsInReg = 52;
|
||||
constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads);
|
||||
float weights[BLOCK_Q][kNumWeightsInReg];
|
||||
DG_STATIC_ASSERT(kNumWeightsInReg <= kNumHeads and kNumWeightsInReg % 4 == 0, "Invalid kNumWeightsInReg");
|
||||
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
|
||||
|
||||
while (block_q_idx < num_q_blocks) {
|
||||
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
||||
@@ -337,7 +338,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
float* accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
|
||||
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
|
||||
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
@@ -367,14 +368,14 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
sum_1 = transform_smem(j + 2, sum_1);
|
||||
}
|
||||
|
||||
float result = sum_0.x + sum_0.y + sum_1.x + sum_1.y;
|
||||
result *= scale_kv;
|
||||
auto sum = __fadd2_rn(sum_0, sum_1);
|
||||
float result = scale_kv * (sum.x + sum.y);
|
||||
|
||||
// Store into the global memory
|
||||
// NOTES: we have redundant writes here, consider more carefully
|
||||
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
|
||||
if constexpr (kIsCompressedLogits) {
|
||||
if (kv_offset + v_offset >= seq_k_start[i])
|
||||
if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i])
|
||||
logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result;
|
||||
} else {
|
||||
logits[q_idx * stride_logits + kv_offset + v_offset] = result;
|
||||
|
||||
@@ -22,8 +22,9 @@ template <uint32_t kNextN, uint32_t kNumHeads,
|
||||
bool kIsContextLens2D,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t SPLIT_KV,
|
||||
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads>
|
||||
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads + 128, 1)
|
||||
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
||||
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
|
||||
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
||||
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
const uint64_t logits_stride, const uint64_t block_table_stride,
|
||||
const uint32_t* context_lens, float* logits,
|
||||
@@ -40,9 +41,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
const auto& lane_idx = get_lane_idx();
|
||||
|
||||
// Prefetch TMA descriptors
|
||||
static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
|
||||
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
||||
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_q);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
||||
@@ -54,78 +53,58 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
// Shared memory configs
|
||||
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
||||
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
|
||||
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
|
||||
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
|
||||
static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
|
||||
constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
|
||||
|
||||
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
||||
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
|
||||
static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
|
||||
constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
|
||||
|
||||
static constexpr uint32_t SMEM_UMMA_SIZE = kNumMathWarpGroups * 2 * 8 + static_cast<uint32_t>(sizeof(uint32_t));
|
||||
|
||||
// Align to swizzling alignment bytes
|
||||
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
||||
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
||||
|
||||
// Q data and barriers on shared memory
|
||||
// Q and KV data on shared memory
|
||||
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
||||
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
|
||||
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
|
||||
|
||||
// Separate math warpgroups and tma load warps into KV groups
|
||||
// Each math warpgroup corresponds to a tma load warp
|
||||
const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
|
||||
|
||||
// Per group KV data and barriers on shared memory
|
||||
const auto& smem_kv_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
|
||||
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * i);
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
||||
constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
|
||||
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
||||
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
|
||||
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
|
||||
|
||||
// UMMA barriers and TMEM pointer on shared memory
|
||||
auto umma_barrier_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kNumMathWarpGroups);
|
||||
// Barriers and TMEM pointer on shared memory
|
||||
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
||||
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
||||
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
|
||||
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
|
||||
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
|
||||
const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
|
||||
auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
|
||||
auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
|
||||
|
||||
constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups;
|
||||
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
|
||||
const bool& is_math_warp = (warp_idx < (kNumMathThreads / 32)); // 0 ~ 16
|
||||
const bool& is_tma_load_warp = (warp_idx >= (kNumMathThreads / 32) and warp_idx < (kNumMathThreads / 32 + 4)); // 16 ~ 20
|
||||
const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 4)); // 20
|
||||
const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4);
|
||||
const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4);
|
||||
const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1);
|
||||
|
||||
// Initialize barriers
|
||||
if (is_tma_load_warp and cute::elect_one_sync()) {
|
||||
if (kv_group_idx == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
||||
full_q_barriers[i]->init(1);
|
||||
empty_q_barriers[i]->init(kNumMathThreads);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
||||
full_q_barriers[i]->init(1);
|
||||
empty_q_barriers[i]->init(kNumMathThreads);
|
||||
}
|
||||
if (kv_group_idx < kNumMathWarpGroups) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
||||
full_kv_barriers[i]->init(1);
|
||||
empty_kv_barriers[i]->init(128);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
||||
full_kv_barriers[i]->init(1);
|
||||
empty_kv_barriers[i]->init(kNumMathThreads);
|
||||
}
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
@@ -144,12 +123,13 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
__syncthreads();
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr uint32_t kNumSpecializedRegisters = 32;
|
||||
constexpr uint32_t kNumMathRegisters = 104;
|
||||
constexpr uint32_t kNumSpecializedRegisters = 40;
|
||||
constexpr uint32_t kNumMathRegisters = 232;
|
||||
|
||||
// Scheduler
|
||||
auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups>(batch_size, blockIdx.x, context_lens, schedule_meta);
|
||||
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
|
||||
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
|
||||
auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit>(batch_size, blockIdx.x, context_lens, schedule_meta);
|
||||
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
|
||||
|
||||
// Q and KV pipeline
|
||||
const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
||||
@@ -161,19 +141,18 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
||||
|
||||
// UMMA settings
|
||||
// Construct instruction with layout F
|
||||
constexpr uint32_t UMMA_M = 64;
|
||||
// Construct instruction with layout D
|
||||
constexpr uint32_t UMMA_M = 128;
|
||||
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
||||
constexpr uint32_t UMMA_N = kNextN * kNumHeads;
|
||||
DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
|
||||
|
||||
if (is_tma_load_warp) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
if (kv_group_idx >= kNumMathWarpGroups)
|
||||
return;
|
||||
|
||||
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
|
||||
if (kv_group_idx == 0 and cute::elect_one_sync()) {
|
||||
if (cute::elect_one_sync()) {
|
||||
tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
|
||||
tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
|
||||
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
||||
@@ -199,6 +178,14 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
kv_idx = next_kv_idx;
|
||||
num_kv = next_num_kv;
|
||||
|
||||
// Read KV block index
|
||||
// TODO: deal with `-1`?
|
||||
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
|
||||
kv_block_idx_ptr = 0;
|
||||
kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0);
|
||||
}
|
||||
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
|
||||
|
||||
// Wait Q consumer release and issue TMA Q
|
||||
if (prefetch_q) {
|
||||
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
||||
@@ -206,25 +193,26 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
issue_tma_q(q_stage_idx, q_idx + 1);
|
||||
}
|
||||
|
||||
// Read KV block index
|
||||
// TODO: deal with `-1`?
|
||||
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
|
||||
kv_block_idx_ptr = 0;
|
||||
kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
|
||||
__ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0);
|
||||
}
|
||||
const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
|
||||
int kv_block_idx[kNumBlocksPerSplit];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumBlocksPerSplit; ++ i)
|
||||
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
|
||||
kv_block_idx_ptr += kNumBlocksPerSplit;
|
||||
|
||||
// Wait KV consumer release
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
||||
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
||||
|
||||
// Issue TMA KV
|
||||
if (cute::elect_one_sync()) {
|
||||
tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
|
||||
smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
|
||||
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
|
||||
smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
|
||||
tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
|
||||
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
|
||||
0, 0, 1, kv_block_idx[i]);
|
||||
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
|
||||
smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
|
||||
0, kv_block_idx[i]);
|
||||
}
|
||||
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
@@ -245,32 +233,26 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
uint32_t q_idx = batch_size, kv_idx;
|
||||
uint32_t next_q_idx, next_kv_idx, next_num_kv;
|
||||
uint32_t q_stage_idx, q_phase;
|
||||
uint32_t umma_phase = 0;
|
||||
|
||||
auto smem_kv = PatternVisitor([&](const uint32_t& stage_idx) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_SIZE_PER_STAGE * stage_idx);
|
||||
});
|
||||
uint32_t umma_phase = 1;
|
||||
|
||||
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
|
||||
if (q_idx != next_q_idx) {
|
||||
if (q_idx != next_q_idx)
|
||||
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
||||
}
|
||||
|
||||
q_idx = next_q_idx;
|
||||
kv_idx = next_kv_idx;
|
||||
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
||||
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
||||
|
||||
DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size");
|
||||
DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
||||
empty_umma_barriers[i]->wait(umma_phase & 1);
|
||||
empty_umma_barriers[i]->wait(umma_phase);
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
||||
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
||||
smem_kv[kv_stage_idx] + i * SMEM_KV_PIPE_SIZE, 0, k * UMMA_K);
|
||||
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
|
||||
auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
||||
smem_q[q_stage_idx], 0, k * UMMA_K);
|
||||
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
|
||||
@@ -285,10 +267,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
|
||||
// Offsets
|
||||
const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
|
||||
float weights[kNextN][kNumHeads / 4];
|
||||
const auto& sub_warp_offset = (warp_idx % 4) * 16;
|
||||
const auto& v_0_offset = lane_idx / 4 + 0;
|
||||
const auto& v_1_offset = lane_idx / 4 + 8;
|
||||
const uint32_t thread_idx = threadIdx.x;
|
||||
|
||||
// Weights
|
||||
constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads));
|
||||
float weights[kNextN][kNumWeightsInReg];
|
||||
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
|
||||
|
||||
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
|
||||
uint32_t q_idx = batch_size, kv_idx;
|
||||
@@ -310,9 +294,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
// Read weights
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNextN; ++ i) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
||||
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
||||
for (uint32_t j = 0; j < kNumWeightsInReg; ++ j)
|
||||
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -321,75 +304,80 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
kv_idx = next_kv_idx;
|
||||
|
||||
// Calculate KV offset in advance
|
||||
auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
|
||||
auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV;
|
||||
|
||||
// Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
|
||||
// Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]`
|
||||
// Wait TMA KV arrival
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
||||
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
||||
|
||||
// Read per-KV scales
|
||||
auto scale_kv = make_float2(ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset),
|
||||
ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset));
|
||||
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx);
|
||||
|
||||
empty_umma_barriers[warpgroup_idx]->arrive();
|
||||
// Wait UMMA arrival
|
||||
full_umma_barriers[warpgroup_idx]->wait(umma_phase & 1);
|
||||
full_umma_barriers[warpgroup_idx]->wait(umma_phase);
|
||||
umma_phase ^= 1;
|
||||
|
||||
// Release KV empty
|
||||
empty_kv_barriers[kv_stage_idx]->arrive();
|
||||
|
||||
// Reduce over the head dim and store
|
||||
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
|
||||
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
||||
constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN;
|
||||
uint32_t shifted_accum[kNumLDTMElems];
|
||||
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM");
|
||||
auto tmem_load = [&](auto... Is) {
|
||||
if constexpr (kNumLDTMElems == 32) {
|
||||
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
|
||||
} else if constexpr (kNumLDTMElems == 64) {
|
||||
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
|
||||
} else if constexpr (kNumLDTMElems == 128) {
|
||||
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
|
||||
}
|
||||
};
|
||||
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
|
||||
empty_umma_barriers[warpgroup_idx]->arrive();
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNextN; ++ i) {
|
||||
// Load from the tensor memory
|
||||
constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128;
|
||||
uint32_t shifted_accum[kNumLDTMElems];
|
||||
DG_STATIC_ASSERT(kNumLDTMElems == 16 or kNumLDTMElems == 32 or kNumLDTMElems == 64, "Invalid LDTM");
|
||||
auto tmem_load = [&](auto... Is) {
|
||||
if constexpr (kNumLDTMElems == 16) {
|
||||
cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
|
||||
} else if constexpr (kNumLDTMElems == 32) {
|
||||
cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
|
||||
} else if constexpr (kNumLDTMElems == 64) {
|
||||
cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
|
||||
}
|
||||
};
|
||||
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
|
||||
|
||||
// Transform
|
||||
const auto& transform_2 = [&](const uint32_t& j, const uint32_t& k, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(*reinterpret_cast<float*>(&shifted_accum[j * 4 + k]), 0),
|
||||
fmaxf(*reinterpret_cast<float*>(&shifted_accum[j * 4 + k + 2]), 0));
|
||||
auto b = make_float2(weights[i][j * 2 + k], weights[i][j * 2 + k]);
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
|
||||
const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
// Intra-thread reduction
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads / 8; ++ j) {
|
||||
sum_0 = transform_2(j, 0, sum_0);
|
||||
sum_1 = transform_2(j, 1, sum_1);
|
||||
for (int j = 0; j < kNumWeightsInReg; j += 4) {
|
||||
sum_0 = transform_reg(j, sum_0);
|
||||
sum_1 = transform_reg(j + 2, sum_1);
|
||||
}
|
||||
auto v = __fmul2_rn(__fadd2_rn(sum_0, sum_1), scale_kv);
|
||||
|
||||
// Inter-thread reduction
|
||||
const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
|
||||
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++ j) {
|
||||
const auto& offset = 1u << j;
|
||||
v.x += __shfl_xor_sync(0xffffffffu, v.x, offset);
|
||||
v.y += __shfl_xor_sync(0xffffffffu, v.y, offset);
|
||||
for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
|
||||
sum_0 = transform_smem(j, sum_0);
|
||||
sum_1 = transform_smem(j + 2, sum_1);
|
||||
}
|
||||
|
||||
auto sum = __fadd2_rn(sum_0, sum_1);
|
||||
float result = scale_kv * (sum.x + sum.y);
|
||||
|
||||
// Store into the global memory
|
||||
// NOTES: we have redundant writes here, consider more carefully
|
||||
logits[kv_offset + i * logits_stride + v_0_offset] = v.x;
|
||||
logits[kv_offset + i * logits_stride + v_1_offset] = v.y;
|
||||
logits[kv_offset + i * logits_stride + thread_idx] = result;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
345
deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh
Normal file
345
deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh
Normal file
@@ -0,0 +1,345 @@
|
||||
#pragma once
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
|
||||
#include <deep_gemm/common/reduction.cuh>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
#include <deep_gemm/common/sm100_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm100;
|
||||
|
||||
// Map a (chunk offset, lane) pair to an offset inside a swizzled shared-memory atom.
//
// The atom is viewed as rows of 128 units, each row split into `kNumBankGroups = 128 / kSwizzleBase`
// chunks of `kSwizzleBase` units. Each lane owns `kSwizzleMode / kSwizzleBase` consecutive chunks,
// and the chunk column is XOR-ed with `row % (kSwizzleMode / kSwizzleBase)` before the final
// offset `row * 128 + col * kSwizzleBase` is formed.
//
// Template parameters:
//  - `kSwizzleMode`: swizzle width; presumably in bytes (e.g. 32/64/128) — TODO confirm against callers
//  - `kSwizzleBase`: size of one chunk (bank group), defaults to 16
// Parameters:
//  - `offset`: index of the chunk within the lane's own `kSwizzleMode / kSwizzleBase` chunks
//  - `lane_idx`: index of the lane whose chunks are being addressed
// Returns: the swizzled offset, in the same unit as `kSwizzleBase`
// NOTE(review): the shortcut path uses `offset` directly as the column, which matches the general
// `bank_group_idx % kNumBankGroups` only when `offset < kNumBankGroups` — presumably guaranteed by
// callers; confirm.
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
|
||||
__device__ __forceinline__
|
||||
uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
|
||||
// Calculate the index of the bank group to be written in the atom
// NOTES: this is the linear chunk index, with `kSwizzleMode / kSwizzleBase` chunks per lane
const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)`
|
||||
// - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)`
|
||||
constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
|
||||
// NOTES: when each lane owns exactly one full 128-unit row of chunks, the division and modulo
// below can be skipped, since `lane_idx * kNumBankGroups` is then a whole number of rows
constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups;
|
||||
auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups);
|
||||
auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups);
|
||||
// XOR the column with the row index (modulo the per-lane chunk count) to apply the swizzle
col ^= row % (kSwizzleMode / kSwizzleBase);
|
||||
|
||||
// Each row spans 128 units and each column is one `kSwizzleBase`-sized chunk
return row * 128 + col * kSwizzleBase;
|
||||
}
|
||||
|
||||
// SM100 TF32 GEMM kernel fused with a per-row sum-of-squares reduction over A.
//
// Warp roles inside one block of `kNumMMAThreads + kNumCastAndReduceThreads` threads:
//  - warp 0: issues the TMA loads of the BF16 A tile and the FP32 B tile for every K block;
//  - warp 1: issues the TF32 UMMAs; the A operand address matches the tensor memory columns
//    written by the cast threads, B is read through a shared memory descriptor;
//  - warps `[0, kNumMMAThreads / 32)`: epilogue, copy the accumulator from tensor memory into
//    swizzled shared memory, then warp 0 writes it out with a TMA store via `tensor_map_d`;
//  - remaining warps: load the BF16 A tile from shared memory, convert it to FP32, accumulate the
//    squared values and store the FP32 values into tensor memory for the UMMAs.
//
// Split-K: `blockIdx.x` encodes `(m_block_idx, k_split_idx)`. With `kNumSplits > 1`, each split stores
// its D tile with a 3D TMA store at coordinate `k_split_idx` and writes its sums at
// `sqr_sum + shape_m * k_split_idx`.
//
// Parameters:
//  - `shape_m`: bounds the `sqr_sum` writes and is the per-split stride inside `sqr_sum`;
//  - `tensor_map_a` / `tensor_map_b` / `tensor_map_d`: TMA descriptors used for A, B and D;
//  - `sqr_sum`: output buffer receiving one sum of squares per row and per K split.
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumSplits,
|
||||
uint32_t kSwizzleCDMode,
|
||||
uint32_t kNumStages,
|
||||
uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
|
||||
__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
|
||||
sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
|
||||
float* sqr_sum) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// Configs
|
||||
// NOTES: `kNumCastStages` is the number of FP32 A buffers kept in tensor memory
|
||||
constexpr uint32_t kNumCastStages = 2;
|
||||
constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
|
||||
constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
|
||||
constexpr auto kMajorA = cute::UMMA::Major::K;
|
||||
constexpr auto kMajorB = cute::UMMA::Major::K;
|
||||
DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages");
|
||||
DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
|
||||
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads");
|
||||
|
||||
// Utils
|
||||
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
const auto lane_idx = get_lane_idx();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
|
||||
// Shared memory sizes
|
||||
constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
|
||||
// NOTES: the A/B stages are placed right after the CD buffer, so the CD size decides their alignment
|
||||
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Real tensor memory size and offsets
|
||||
// NOTES: columns `[0, BLOCK_K * kNumCastStages)` hold the casted A, the accumulator follows
|
||||
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
}
|
||||
|
||||
// Data on shared memory (layout as ordered below)
|
||||
// Fill D/A/B pointers
|
||||
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
|
||||
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
||||
});
|
||||
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
||||
});
|
||||
|
||||
// Fill barriers
|
||||
// NOTES: four arrays of `kNumStages` barriers each (full, full-cast, empty, empty-cast),
|
||||
// followed by the tensor memory full barrier and the tensor memory pointer slot
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
|
||||
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
||||
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
||||
auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
||||
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
||||
auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
|
||||
auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
|
||||
|
||||
// Fill the tensor memory pointer
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 4 + 1);
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
// Initialize barriers
|
||||
if (warp_idx == 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
// NOTES: every cast-and-reduce thread arrives at the full-cast barrier, all others have a single arrival
|
||||
full_barriers[i]->init(1);
|
||||
full_cast_barriers[i]->init(kNumCastAndReduceThreads);
|
||||
empty_barriers[i]->init(1);
|
||||
empty_cast_barriers[i]->init(1);
|
||||
}
|
||||
tmem_full_barrier->init(1);
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (warp_idx == 2) {
|
||||
// Allocate tensor memory
|
||||
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Split the K blocks among `kNumSplits` blocks, the first `kRemainKBlocks` splits take one extra K block
|
||||
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
|
||||
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
|
||||
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
|
||||
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
|
||||
const uint32_t m_block_idx = block_idx / kNumSplits;
|
||||
const uint32_t k_split_idx = block_idx % kNumSplits;
|
||||
const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
|
||||
// NOTES: `m_offset` selects the section of `sqr_sum` that belongs to this K split
|
||||
const uint32_t m_offset = shape_m * k_split_idx;
|
||||
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
|
||||
|
||||
// Dispatch warps into different roles
|
||||
if (warp_idx < kNumMMAThreads / 32) {
|
||||
// TMA load warp
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
||||
// Wait consumer release
|
||||
const auto& stage_idx = s % kNumStages;
|
||||
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
|
||||
|
||||
// Compute offsets
|
||||
uint32_t m_idx = m_block_idx * BLOCK_M;
|
||||
uint32_t k_idx = k_offset + s * BLOCK_K;
|
||||
|
||||
// Issue TMAs
|
||||
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
|
||||
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
|
||||
|
||||
// Arrive at full barriers
|
||||
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
|
||||
}
|
||||
}
|
||||
|
||||
// MMA issue warp
|
||||
if (warp_idx == 1) {
|
||||
// Make instruction descriptor
|
||||
constexpr uint32_t UMMA_M = BLOCK_M;
|
||||
constexpr uint32_t UMMA_N = BLOCK_N;
|
||||
constexpr uint32_t UMMA_K = 32 / sizeof(float);
|
||||
constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float);
|
||||
using umma_t = cute::SM100_MMA_TF32_TS<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
|
||||
BLOCK_M, BLOCK_N, kMajorA, kMajorB>;
|
||||
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
|
||||
UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
||||
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
||||
|
||||
// NOTES: lane `i` keeps the low descriptor word of B stage `i`, it is fetched with a warp shuffle
|
||||
// inside the loop, so the number of stages can not exceed the warp size
|
||||
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
||||
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
|
||||
const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
||||
|
||||
// Checks for MMA instructions
|
||||
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
|
||||
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
||||
"Invalid MMA instruction shape");
|
||||
|
||||
// Launch MMAs
|
||||
// We can not unroll this part
|
||||
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
||||
// Wait until the cast-and-reduce threads have filled this cast stage in tensor memory
|
||||
// NOTES: those threads wait for the TMA full barrier of the stage before arriving here
|
||||
const auto& stage_idx = s % kNumStages;
|
||||
const auto& cast_stage_idx = s % kNumCastStages;
|
||||
full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Issue UMMA
|
||||
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
|
||||
const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
|
||||
const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
|
||||
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
|
||||
// NOTES: A comes from the cast stage columns, the accumulator lives at column `BLOCK_K * kNumCastStages`,
|
||||
// and `s > 0 or k > 0` is false only for the very first UMMA of the block
|
||||
umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
|
||||
}
|
||||
|
||||
// Commit
|
||||
// NOTES: releases the cast stage for the cast threads and the shared memory stage for the TMA warp
|
||||
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_cast_barriers[cast_stage_idx]));
|
||||
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
|
||||
}
|
||||
|
||||
// Commit to epilogue threads
|
||||
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
|
||||
}
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
|
||||
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
||||
DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
||||
|
||||
// Only support layout F (M = 64) and D (M = 128)
|
||||
// NOTES: the cast-and-reduce branch below additionally asserts `BLOCK_M == 64`
|
||||
DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M");
|
||||
|
||||
// Wait UMMA arrival
|
||||
tmem_full_barrier->wait(0);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Load from tensor memory into registers, and write shared memory with STSM
|
||||
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
|
||||
|
||||
// Store into shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
||||
// Source and destination memory address
|
||||
uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup;
|
||||
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) + // Base pointer
|
||||
warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset
|
||||
get_swizzled_smem_offset<kSwizzleCDMode>(i, lane_idx); // In-atom offset
|
||||
|
||||
// Load from tensor memory, store into shared memory
|
||||
uint32_t values[kNumElemsPerBankGroup];
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
// NOTES: with `BLOCK_M == 64` only lanes 0-15 of each warp write into shared memory
|
||||
if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
|
||||
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
||||
if constexpr (BLOCK_M == 64)
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0);
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
// NOTES: with split-K, the split index is passed as the third TMA coordinate
|
||||
if constexpr (kNumSplits == 1) {
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
|
||||
} else {
|
||||
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
|
||||
}
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
|
||||
// Deallocate tensor memory by warp 1
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
if (warp_idx == 1)
|
||||
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
||||
} else {
|
||||
// Cast-and-reduce warps: convert the BF16 A tile into FP32 in tensor memory and accumulate its squares
|
||||
DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
|
||||
DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads");
|
||||
constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
|
||||
const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32;
|
||||
|
||||
// TODO: make even larger block K
|
||||
DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
|
||||
|
||||
// Launch reductions
|
||||
// NOTES: `sum[0]` and `sum[1]` belong to the two rows handled by a thread, see the write back below
|
||||
float2 sum[2] = {float2{0, 0}, float2{0, 0}};
|
||||
#pragma unroll kNumStages
|
||||
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
||||
// Wait TMA arrival
|
||||
const auto& stage_idx = s % kNumStages;
|
||||
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
|
||||
|
||||
// Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b)
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
|
||||
constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
|
||||
const auto& smem_base_ptr = reinterpret_cast<uint8_t*>(smem_a[stage_idx]) + // Base pointer
|
||||
sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset
|
||||
|
||||
// 4 lanes shared a bank group
|
||||
uint32_t uint32_values[2][kNumLoads];
|
||||
DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads");
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumLoads; i += 2) {
|
||||
// NOTES: lanes 0-15 address bank group `i`, lanes 16-31 address bank group `i + 1`
|
||||
auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
|
||||
sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
|
||||
uint32_values[0][i + 1], uint32_values[1][i + 1],
|
||||
smem_ptr);
|
||||
}
|
||||
|
||||
// Wait tensor memory empty
|
||||
const auto& cast_stage_idx = s % kNumCastStages;
|
||||
empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1);
|
||||
|
||||
// Cast, reduce and store into tensor memory
|
||||
float2 fp32x2_values[2][kNumLoads];
|
||||
const auto& upper_view = reinterpret_cast<uint32_t*>(&fp32x2_values[0]);
|
||||
const auto& lower_view = reinterpret_cast<uint32_t*>(&fp32x2_values[1]);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumLoads; ++ i) {
|
||||
#pragma unroll
|
||||
for (uint32_t u = 0; u < 2; ++ u) {
|
||||
// NOTES: `sum[u] += value * value`, element-wise on the two packed floats
|
||||
fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast<nv_bfloat162*>(&uint32_values[u][i]));
|
||||
sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]);
|
||||
}
|
||||
|
||||
// Store upper and lower part at the same time
|
||||
const auto idx_0 = i * 2, idx_1 = i * 2 + 1;
|
||||
cute::SM100_TMEM_STORE_16dp256b1x::copy(
|
||||
upper_view[idx_0], upper_view[idx_1],
|
||||
lower_view[idx_0], lower_view[idx_1],
|
||||
cast_stage_idx * BLOCK_K + i * 8);
|
||||
}
|
||||
cutlass::arch::fence_view_async_tmem_store();
|
||||
|
||||
// Arrive for issuing MMAs
|
||||
tcgen05_before_thread_sync();
|
||||
full_cast_barriers[cast_stage_idx]->arrive();
|
||||
}
|
||||
|
||||
// Intra-warp reduction and write back
|
||||
// NOTES: `u` selects the row 8 below; `warp_reduce_sum<4>` presumably combines the 4 lanes sharing a row
|
||||
// (TODO confirm against `reduction.cuh`). Only one lane out of 4 writes, rows beyond `shape_m` are skipped
|
||||
#pragma unroll
|
||||
for (uint32_t u = 0; u < 2; ++ u) {
|
||||
const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y);
|
||||
const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
|
||||
if (lane_idx % 4 == 0 and m_idx < shape_m)
|
||||
sqr_sum[m_offset + m_idx] = reduced_sum;
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Compiled for an unsupported architecture: fail loudly from a single thread
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
@@ -350,7 +350,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
|
||||
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
|
||||
DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
|
||||
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
|
||||
|
||||
@@ -171,20 +171,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
empty_barriers[stage_idx]->wait(phase ^ 1);
|
||||
|
||||
// Issue TMA A
|
||||
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
|
||||
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
|
||||
|
||||
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
||||
auto& full_barrier = *full_barriers[stage_idx];
|
||||
const uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, &full_barrier,
|
||||
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
|
||||
smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
|
||||
num_tma_multicast_a);
|
||||
num_tma_multicast_a, batch_idx);
|
||||
tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
|
||||
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_block_idx),
|
||||
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
|
||||
num_tma_multicast_a);
|
||||
|
||||
// Issue TMA B
|
||||
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, &full_barrier,
|
||||
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
|
||||
smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
|
||||
num_tma_multicast_b);
|
||||
num_tma_multicast_b, batch_idx);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
|
||||
}
|
||||
}
|
||||
@@ -222,7 +225,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
// Load B scales with math warp-groups
|
||||
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
|
||||
if (threadIdx.x >= 32) {
|
||||
auto previous_group_offset = scheduler.get_global_idx<true>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
|
||||
auto previous_group_offset = scheduler.template get_global_idx<true, IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
|
||||
const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
|
||||
const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
|
||||
auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
|
||||
@@ -413,9 +416,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
|
||||
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
|
||||
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
||||
epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset),
|
||||
scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx));
|
||||
auto n_idx = epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset);
|
||||
auto m_idx = scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx);
|
||||
if constexpr (kGemmType == GemmType::Batched) {
|
||||
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr,
|
||||
n_idx, m_idx, scheduler.current_group_idx);
|
||||
} else {
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx);
|
||||
}
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
@@ -127,7 +127,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {block_q_idx + gridDim.x, q_iter_idx + 1};
|
||||
};
|
||||
uint32_t seq_k_start[BLOCK_Q];
|
||||
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
|
||||
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
|
||||
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
||||
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
||||
@@ -136,8 +136,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
|
||||
seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
|
||||
seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
|
||||
start = min(start, min(seq_k_start[i], seq_len_kv));
|
||||
end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv));
|
||||
end = max(end, min(seq_k_end[i], seq_len_kv));
|
||||
}
|
||||
start = start / 4 * 4;
|
||||
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
|
||||
@@ -304,9 +305,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
// NOTES: we have redundant writes here, consider more carefully
|
||||
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
|
||||
if constexpr (kIsCompressedLogits) {
|
||||
if (kv_offset + v_0_offset >= seq_k_start[i])
|
||||
if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
|
||||
logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0;
|
||||
if (kv_offset + v_1_offset >= seq_k_start[i])
|
||||
if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
|
||||
logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1;
|
||||
} else {
|
||||
logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0;
|
||||
|
||||
@@ -58,7 +58,7 @@ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t ne
|
||||
}
|
||||
|
||||
template <uint32_t kNextN, bool kIsContextLens2D,
|
||||
uint32_t BLOCK_KV, uint32_t kNumMathWarpGroups>
|
||||
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit>
|
||||
struct PagedMQALogitsScheduler {
|
||||
uint32_t batch_size;
|
||||
const uint32_t* context_lens;
|
||||
@@ -79,8 +79,8 @@ struct PagedMQALogitsScheduler {
|
||||
|
||||
const auto& current_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx);
|
||||
const auto& end_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx + 1);
|
||||
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumMathWarpGroups;
|
||||
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumMathWarpGroups;
|
||||
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
|
||||
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
|
||||
|
||||
current_num_kv = get_num_kv(current_q_idx);
|
||||
}
|
||||
@@ -93,7 +93,7 @@ struct PagedMQALogitsScheduler {
|
||||
if (q_idx == end_q_idx and kv_idx == end_kv_idx)
|
||||
return false;
|
||||
|
||||
current_kv_idx += kNumMathWarpGroups;
|
||||
current_kv_idx += kNumBlocksPerSplit;
|
||||
if (current_kv_idx >= current_num_kv) {
|
||||
++ current_q_idx;
|
||||
current_kv_idx = 0;
|
||||
|
||||
287
deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh
Normal file
287
deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh
Normal file
@@ -0,0 +1,287 @@
|
||||
#pragma once
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <deep_gemm/common/reduction.cuh>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
// Returns the swizzled bank-group index, reduced modulo `kSwizzleMode / kSwizzleBase`, for the
// bank group at position `offset` in the row owned by `lane_idx`.
// The atom is viewed as rows of `kNumBankGroups` bank groups; the column is XOR-ed with
// `row % (kSwizzleMode / kSwizzleBase)` before the linear index is folded back into the swizzle range.
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
__device__ __forceinline__
uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
    constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;
    constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;

    uint32_t row_idx, col_idx;
    if constexpr (kGroupsInSwizzleRange == kNumBankGroups) {
        // Shortcut: one lane covers exactly one row of bank groups
        row_idx = offset / kNumBankGroups + lane_idx;
        col_idx = offset;
    } else {
        // Linear bank-group index, split into the row/column of the reshaped view
        const uint32_t linear_idx = offset + lane_idx * kGroupsInSwizzleRange;
        row_idx = linear_idx / kNumBankGroups;
        col_idx = linear_idx % kNumBankGroups;
    }

    // Swizzle the column by the row
    col_idx ^= row_idx % kGroupsInSwizzleRange;

    return (row_idx * kNumBankGroups + col_idx) % kGroupsInSwizzleRange;
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumSplits,
|
||||
uint32_t kSwizzleCDMode,
|
||||
uint32_t kNumStages,
|
||||
uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
|
||||
__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
|
||||
sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
|
||||
float* sqr_sum) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// kSwizzleAMode and kSwizzleBMode must be 128 for now
|
||||
constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
|
||||
constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
|
||||
DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K");
|
||||
DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");
|
||||
DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode");
|
||||
|
||||
DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
|
||||
DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads");
|
||||
|
||||
// Utils
|
||||
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
const auto lane_idx = get_lane_idx();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
|
||||
// Share memory sizes
|
||||
constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
|
||||
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
}
|
||||
|
||||
// Data on shared memory (layout as ordered below)
|
||||
// Fill D/A/B pointers
|
||||
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
|
||||
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
||||
});
|
||||
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
||||
});
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
||||
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
||||
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
||||
|
||||
// Initialize barriers
|
||||
if (warp_idx == 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(128);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
|
||||
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
|
||||
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
|
||||
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
|
||||
const uint32_t m_block_idx = block_idx / kNumSplits;
|
||||
const uint32_t k_split_idx = block_idx % kNumSplits;
|
||||
const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
|
||||
const uint32_t m_offset = shape_m * k_split_idx;
|
||||
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
|
||||
constexpr uint32_t kNumTMARegisters = 40;
|
||||
constexpr uint32_t kNumMathRegisters = 256;
|
||||
|
||||
// TMA load warp
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
||||
// Wait consumer release
|
||||
const auto& stage_idx = s % kNumStages;
|
||||
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
|
||||
|
||||
// Compute offsets
|
||||
uint32_t m_idx = m_block_idx * BLOCK_M;
|
||||
uint32_t k_idx = k_offset + s * BLOCK_K;
|
||||
|
||||
// Issue TMAs
|
||||
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
|
||||
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
|
||||
|
||||
// Arrive at full barriers
|
||||
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
|
||||
}
|
||||
|
||||
for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
|
||||
const auto& stage_idx = s % kNumStages;
|
||||
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
|
||||
}
|
||||
} else if (warp_idx < kNumMathThreads / 32) {
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
|
||||
DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
|
||||
constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
|
||||
constexpr uint32_t WGMMA_M = 64;
|
||||
constexpr uint32_t WGMMA_N = BLOCK_N;
|
||||
constexpr uint32_t WGMMA_K = 8;
|
||||
|
||||
using WGMMA = typename TF32MMASelector<WGMMA_N, true>::type;
|
||||
float accum[WGMMA::kNumAccum] = {0};
|
||||
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
|
||||
constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
|
||||
float sqr_sum_acc_0 = 0;
|
||||
float sqr_sum_acc_1 = 0;
|
||||
|
||||
#pragma unroll kNumStages < 8 ? kNumStages : kNumStages / 2
|
||||
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
||||
// Wait TMA arrival
|
||||
const auto& stage_idx = s % kNumStages;
|
||||
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
|
||||
|
||||
constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128;
|
||||
constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K;
|
||||
|
||||
float a[kNumRegPerWgmma * kNumWgmmaPerBlockK];
|
||||
// Assume swizzle A mode is 128
|
||||
DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");
|
||||
|
||||
// Load BF16 A fragment from shared memory into registers, and transpose to FP32
|
||||
uint32_t row = warp_idx * 16 + lane_idx / 4;
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumLoads; ++ i) {
|
||||
// Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a
|
||||
uint32_t bank_group_idx = (row ^ i) % 8;
|
||||
nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;
|
||||
nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;
|
||||
|
||||
uint32_t elem_offset = lane_idx % 4;
|
||||
nv_bfloat16 a_bf16[kNumRegPerWgmma];
|
||||
a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset];
|
||||
a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4];
|
||||
a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset];
|
||||
a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4];
|
||||
|
||||
auto a_bf16x2_ptr = reinterpret_cast<nv_bfloat162*>(a_bf16);
|
||||
auto a_float2_ptr = reinterpret_cast<float2*>(a);
|
||||
float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]);
|
||||
float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]);
|
||||
a_float2_ptr[i * 2 + 0] = a_float2_0;
|
||||
a_float2_ptr[i * 2 + 1] = a_float2_1;
|
||||
sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x;
|
||||
sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
|
||||
}
|
||||
|
||||
warpgroup_wait<0>();
|
||||
if (s > 0)
|
||||
empty_barriers[(s - 1) % kNumStages]->arrive();
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
|
||||
constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
|
||||
constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
|
||||
DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K");
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
|
||||
auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
|
||||
}
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
}
|
||||
|
||||
const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0);
|
||||
const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1);
|
||||
|
||||
const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
|
||||
if (lane_idx % 4 == 0) {
|
||||
if (m_idx < shape_m)
|
||||
sqr_sum[m_offset + m_idx] = reduced_sum_0;
|
||||
if (m_idx + 8 < shape_m)
|
||||
sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
|
||||
}
|
||||
warpgroup_wait<0>();
|
||||
empty_barriers[(num_total_stages-1) % kNumStages]->arrive();
|
||||
|
||||
// Write accum to shared memory
|
||||
// Every 2 threads (one pair) will write to the same bank group (16 bytes).
|
||||
// Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d
|
||||
uint32_t is_odd_pair = lane_idx / 2 % 2;
|
||||
|
||||
// Four threads per group; write the data to the same row.
|
||||
uint32_t row_idx = lane_idx / 4;
|
||||
|
||||
// Even/odd index pairs write to the same column, we need to reorder idx:
|
||||
// group even pair indices consecutively, and likewise for odd ones.
|
||||
uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx;
|
||||
|
||||
auto shifted_smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) +
|
||||
(warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows
|
||||
lane_idx % 2 * 8; // One thread of a pair writes 8 bytes
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) {
|
||||
// Get the swizzled bank group index (16 bytes per group)
|
||||
uint32_t bank_group_idx = get_swizzled_bank_group_idx<kSwizzleCDMode>(i + is_odd_pair, reordered_pair_idx);
|
||||
auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group
|
||||
|
||||
// 0/1 write to the same row, 2/3 write to another row
|
||||
auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
|
||||
st_shared(smem_ptr, values[0], values[1]);
|
||||
st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier::sync(128, 1);
|
||||
|
||||
// Issue TMA stores
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
if constexpr (kNumSplits == 1) {
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
|
||||
} else {
|
||||
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
|
||||
}
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
@@ -5,6 +5,8 @@ from typing import Iterable
|
||||
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a normalized dissimilarity between two tensors.

    Both inputs are promoted to float64 and the result is
    ``1 - 2 * sum(x * y) / sum(x * x + y * y)``, so identical inputs give 0.

    Args:
        x: First tensor.
        y: Second tensor, broadcastable against ``x``.

    Returns:
        A 0-dim float64 tensor, or the Python float ``0.0`` when the
        denominator is exactly zero.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:  # Which means that all elements in x and y are 0
        # Avoid 0 / 0: two all-zero tensors are treated as identical
        return 0.0
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
|
||||
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
from .._C import (
|
||||
get_tma_aligned_size,
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
get_mn_major_tma_aligned_tensor,
|
||||
get_mn_major_tma_aligned_packed_ue8m0_tensor,
|
||||
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
|
||||
)
|
||||
try:
|
||||
from .._C import (
|
||||
get_tma_aligned_size,
|
||||
get_mn_major_tma_aligned_tensor,
|
||||
get_mn_major_tma_aligned_packed_ue8m0_tensor,
|
||||
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
|
||||
)
|
||||
except ImportError:
|
||||
# Expected behavior for CUDA runtime version before 12.1
|
||||
pass
|
||||
|
||||
# Valid for all CUDA versions
|
||||
from .._C import get_mk_alignment_for_contiguous_layout
|
||||
|
||||
# Some alias
|
||||
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
|
||||
|
||||
@@ -15,35 +15,35 @@ def ceil_to_ue8m0(x: torch.Tensor):
|
||||
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
padded_n = align(n, 128)
|
||||
padded_n = align(n, gran_k)
|
||||
x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0)
|
||||
x_padded[:, :n] = x
|
||||
x_view = x_padded.view(m, -1, 128)
|
||||
x_view = x_padded.view(m, -1, gran_k)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
|
||||
|
||||
|
||||
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2 and x.size(0) % 128 == 0
|
||||
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2 and x.size(0) % gran_k == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(-1, 128, n)
|
||||
x_view = x.view(-1, gran_k, n)
|
||||
x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device)
|
||||
x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
@@ -58,3 +58,50 @@ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled, sf.squeeze()
|
||||
|
||||
|
||||
def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
|
||||
ax = x.abs().clamp_max(6.0)
|
||||
# {0, 0.5, 1, 1.5, 2, 3, 4, 6}
|
||||
# midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0
|
||||
boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0],
|
||||
device=x.device, dtype=ax.dtype)
|
||||
idx = torch.bucketize(ax, boundaries)
|
||||
code = idx.to(torch.uint8)
|
||||
sign = (x < 0) & (idx != 0)
|
||||
code = code | (sign.to(torch.uint8) << 3)
|
||||
return code # uint8, 0..15
|
||||
|
||||
|
||||
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D tensor to packed FP4 (E2M1) with one scale per row per `gran_k` columns.

    Args:
        x: 2D input of shape (m, n); `n` must be even so that codes pack in pairs.
        use_ue8m0: If true, scales are passed through `ceil_to_ue8m0`.
        gran_k: Number of consecutive columns sharing one scale (default 128).

    Returns:
        A tuple `(packed, sf)`: a contiguous (m, n // 2) `torch.uint8` tensor
        holding two codes per byte (even column in the low nibble, odd column
        in the high nibble), and the float scales of shape
        (m, align(n, gran_k) // gran_k).
    """
    assert x.dim() == 2
    m, n = x.shape
    assert n % 2 == 0
    # Zero-pad the columns so that they split evenly into `gran_k`-wide groups
    padded_n = align(n, gran_k)
    x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
    x_padded[:, :n] = x
    x_view = x_padded.view(m, -1, gran_k)
    # Clamp the per-group absolute maximum so that the scale never becomes zero
    x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4)
    # 6.0 is the largest magnitude in the E2M1 value set
    sf = x_amax / 6.0
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = x_view * (1.0 / sf.unsqueeze(2))
    codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n)  # uint8, (m, padded_n)
    # Pack each pair of adjacent columns into a single byte
    codes2 = codes.view(m, padded_n // 2, 2)
    packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4)  # uint8
    return packed[:, :n // 2].contiguous(), sf
|
||||
|
||||
|
||||
def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
    """Transpose a matrix of packed 4-bit codes and re-pack it along the new last dim.

    Args:
        a: `torch.uint8` tensor of shape (m, n // 2), two codes per byte with
            the even column in the low nibble; `m` must be even.

    Returns:
        A contiguous `torch.uint8` tensor of shape (n, m // 2) holding the
        transposed codes, packed the same way.
    """
    assert a.dtype == torch.uint8
    assert a.dim() == 2
    m, n2 = a.shape
    n = n2 * 2
    assert (m % 2) == 0
    # Unpack every byte into its two 4-bit codes
    lo = a & 0x0F
    hi = (a >> 4) & 0x0F
    codes = torch.empty((m, n), device=a.device, dtype=torch.uint8)
    codes[:, 0::2], codes[:, 1::2] = lo, hi
    # Transpose at code granularity, then pack pairs along the new last dim
    codes_t = codes.transpose(0, 1).contiguous()
    codes2 = codes_t.view(n, m // 2, 2)
    out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4)
    return out.contiguous()
|
||||
@@ -1,12 +1,13 @@
|
||||
import enum
|
||||
import random
|
||||
import torch
|
||||
from typing import Generator, List
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
from deep_gemm.testing import get_arch_major
|
||||
from deep_gemm.utils import (
|
||||
align, ceil_div,
|
||||
per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8,
|
||||
per_token_cast_to_fp4, transpose_packed_fp4,
|
||||
get_mk_alignment_for_contiguous_layout
|
||||
)
|
||||
|
||||
@@ -35,6 +36,51 @@ class MajorTypeAB(enum.Enum):
|
||||
|
||||
def is_mn_major(self):
    # Method of `MajorTypeAB`: true when this member's value is 1
    return self.value == 1
|
||||
|
||||
|
||||
class QuantConfig:
    """Quantization settings for one test case.

    Holds the K granularity of the scaling factors of A and B, and whether
    each operand is quantized to FP4.
    """

    # (gran_k_a, gran_k_b, is_fp4_a, is_fp4_b) of the legacy configuration
    _legacy_quant_config = (128, 128, False, False)

    def __init__(self, value: Tuple[int, int, bool, bool] = _legacy_quant_config):
        self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b = value

    def print(self):
        """Print the configuration as one indented log line."""
        print(f' > Testing with gran_k_a={self.gran_k_a}, gran_k_b={self.gran_k_b}, '
              f'is_fp4_a={self.is_fp4_a}, is_fp4_b={self.is_fp4_b}')

    def is_legacy(self) -> bool:
        """Return whether all four fields equal the legacy configuration."""
        return (self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b) == self._legacy_quant_config

    def get_recipes(self, is_wgrad: bool = False) -> Tuple[Tuple, Tuple, Tuple]:
        """Return `(recipe, recipe_a, recipe_b)`; entries that do not apply are `None`.

        The legacy configuration only sets `recipe` (and only for wgrad);
        any other configuration only sets `recipe_a` and `recipe_b`.
        """
        recipe, recipe_a, recipe_b = None, None, None
        if self.is_legacy():
            recipe = (1, 1, 128) if is_wgrad else None
        else:
            recipe_a = (1, self.gran_k_a)
            # B uses a (1, gran_k) recipe for FP4 or wgrad, otherwise a square block
            recipe_b = (1, self.gran_k_b) if self.is_fp4_b or is_wgrad else (self.gran_k_b, self.gran_k_b)
        return recipe, recipe_a, recipe_b

    def max_diff(self) -> float:
        """Return the error tolerance: looser with each operand that is FP4."""
        if self.is_fp4_a and self.is_fp4_b:
            return 0.02
        if self.is_fp4_a or self.is_fp4_b:
            return 0.01
        return 0.001

    @staticmethod
    def get_list_from_dtype(dtype: torch.dtype) -> List:
        """Return the configurations to test for `dtype`.

        BF16 yields `[None]`. Other dtypes yield the legacy configuration,
        plus an FP8 x FP4 configuration when `get_arch_major()` is 10.
        """
        if dtype == torch.bfloat16:
            return [None]
        quant_config_list = [QuantConfig()]
        if get_arch_major() == 10:
            quant_config_list.append(QuantConfig((128, 32, False, True)))
        return quant_config_list
|
||||
|
||||
|
||||
def reset_seed(seed: int = 0):
    """Seed Python's `random`, the PyTorch default generator and the CUDA generator.

    Args:
        seed: Seed value applied to all three generators (default 0).
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
|
||||
|
||||
|
||||
def get_ue8m0_usage(kernel_type: KernelType) -> bool:
|
||||
@@ -60,9 +106,14 @@ def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator:
|
||||
yield major_a, major_b
|
||||
|
||||
|
||||
def get_psum_layout_usage() -> tuple:
    """Return the `use_psum_layout` values to test.

    Both `False` and `True` when `get_arch_major()` is 10, otherwise only `False`.
    """
    return (False, True) if get_arch_major() == 10 else (False, )
|
||||
|
||||
|
||||
def enumerate_normal(dtype: torch.dtype) -> Generator:
|
||||
assert dtype in (torch.float8_e4m3fn, torch.bfloat16)
|
||||
|
||||
quant_config_list = QuantConfig.get_list_from_dtype(dtype)
|
||||
fp32_output_nk = [(256, 7168), (129280, 7168)]
|
||||
bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]
|
||||
m_fwd_list, m_bwd_list = [1, 128, 4096], [4096, ]
|
||||
@@ -73,39 +124,61 @@ def enumerate_normal(dtype: torch.dtype) -> Generator:
|
||||
nk_list += fp32_output_nk
|
||||
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
# Forward
|
||||
for m in m_fwd_list:
|
||||
for i in range(len(nk_list)):
|
||||
n, k = nk_list[i]
|
||||
out_dtype = torch.bfloat16 if i < len(bf16_output_nk) else torch.float
|
||||
yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype
|
||||
for quant_config in quant_config_list:
|
||||
if len(quant_config_list) > 1:
|
||||
quant_config.print()
|
||||
reset_seed()
|
||||
|
||||
# Backward
|
||||
for m in m_bwd_list:
|
||||
for n, k in nk_list:
|
||||
override_major = MajorTypeAB.MNMajor
|
||||
override_kernel_type = kernel_type
|
||||
if get_arch_major() == 9 and dtype == torch.float8_e4m3fn:
|
||||
override_major = MajorTypeAB.KMajor
|
||||
override_kernel_type = KernelType.Kernel1D1D
|
||||
yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad
|
||||
yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad
|
||||
yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad
|
||||
# Forward
|
||||
for m in m_fwd_list:
|
||||
for i in range(len(nk_list)):
|
||||
n, k = nk_list[i]
|
||||
out_dtype = torch.bfloat16 if i < len(bf16_output_nk) else torch.float
|
||||
yield kernel_type, quant_config, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype
|
||||
|
||||
# Backward
|
||||
for m in m_bwd_list:
|
||||
for n, k in nk_list:
|
||||
override_major = MajorTypeAB.MNMajor
|
||||
override_kernel_type = kernel_type
|
||||
if get_arch_major() == 9 and dtype == torch.float8_e4m3fn:
|
||||
override_major = MajorTypeAB.KMajor
|
||||
override_kernel_type = KernelType.Kernel1D1D
|
||||
yield kernel_type, quant_config, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad
|
||||
yield override_kernel_type, quant_config, n, m, k, override_major, override_major, True, torch.float # Wgrad
|
||||
yield override_kernel_type, quant_config, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad
|
||||
|
||||
|
||||
def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator:
|
||||
quant_config_list = QuantConfig.get_list_from_dtype(dtype)
|
||||
m_group_list = [(4, 8192), (8, 4096)]
|
||||
n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)]
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)):
|
||||
for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn):
|
||||
yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b
|
||||
for quant_config in quant_config_list:
|
||||
if len(quant_config_list) > 1:
|
||||
quant_config.print()
|
||||
for use_psum_layout in get_psum_layout_usage():
|
||||
reset_seed()
|
||||
for num_groups, expected_m_per_group in m_group_list:
|
||||
for n, k in n_k_list:
|
||||
for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn):
|
||||
yield kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout
|
||||
|
||||
|
||||
def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator:
    """Yield the test cases for M-grouped masked GEMMs.

    Each item is `(kernel_type, quant_config, num_groups, max_m, m, n, k,
    use_psum_layout)`, where `max_m` is fixed to 4096.
    """
    quant_config_list = QuantConfig.get_list_from_dtype(dtype)
    max_m = 4096
    m_group_list = [(6, 1024), (32, 192), (32, 50)]
    n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)]
    for kernel_type in get_kernel_types(dtype):
        for quant_config in quant_config_list:
            # Only announce the configuration when there is more than one to tell apart
            if len(quant_config_list) > 1:
                quant_config.print()
            for use_psum_layout in get_psum_layout_usage():
                # Re-seed for every (config, layout) sweep
                reset_seed()
                for num_groups, m in m_group_list:
                    for n, k in n_k_list:
                        yield kernel_type, quant_config, num_groups, max_m, m, n, k, use_psum_layout
|
||||
|
||||
|
||||
def enumerate_k_grouped_contiguous(dtype: torch.dtype):
|
||||
@@ -145,11 +218,46 @@ def enumerate_transpose():
|
||||
yield mn + delta, k
|
||||
|
||||
|
||||
def cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool,
                            use_ue8m0: bool, use_block_cast_for_fp8: bool = False):
    """Quantize a 2D operand to FP8 or packed FP4 and lay it out for `major`.

    Args:
        x: 2D input tensor.
        major: Requested layout; K-major keeps the cast result as is.
        gran_k: K granularity of the scaling factors.
        is_fp4: Quantize with `per_token_cast_to_fp4` instead of an FP8 cast.
        use_ue8m0: Forwarded to the cast helper.
        use_block_cast_for_fp8: For FP8, use `per_block_cast_to_fp8` instead
            of `per_token_cast_to_fp8`.

    Returns:
        A `(data, scaling_factors)` tuple.
    """
    if is_fp4:
        x_fp4 = per_token_cast_to_fp4(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
        # For the non-K-major layout, transpose at code granularity, then take a transposed view
        x = x_fp4 if major.is_k_major() else (transpose_packed_fp4(x_fp4[0]).T, x_fp4[1])
    else:
        x_fp8 = per_block_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \
            else per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
        # Same logical shape, but with the transposed memory order
        x = x_fp8 if major.is_k_major() else (x_fp8[0].T.contiguous().T, x_fp8[1])
    return x
|
||||
|
||||
|
||||
def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool,
                                    use_ue8m0: bool, use_block_cast_for_fp8: bool = False):
    """Quantize every group of a 3D operand to FP8 or packed FP4 and lay it out for `major`.

    Args:
        x: Input of shape (num_groups, mn, k).
        major: Requested layout; K-major keeps the cast result as is.
        gran_k: K granularity of the scaling factors.
        is_fp4: Quantize with `per_token_cast_to_fp4` instead of an FP8 cast.
        use_ue8m0: Forwarded to the cast helper.
        use_block_cast_for_fp8: For FP8, use `per_block_cast_to_fp8` instead
            of `per_token_cast_to_fp8`.

    Returns:
        A `(data, scaling_factors)` tuple; both tensors are batched over groups.
    """
    num_groups, mn, k = x.size()
    # Allocate the outputs next to the input instead of on the default CUDA device
    device = x.device
    if is_fp4:
        # Two FP4 codes per byte, so the packed dim is halved
        x_fp4 = (torch.empty((num_groups, mn, k // 2), device=device, dtype=torch.uint8) if major.is_k_major() else \
                 torch.empty((num_groups, k, mn // 2), device=device, dtype=torch.uint8),
                 torch.empty((num_groups, mn, ceil_div(k, gran_k)), device=device, dtype=torch.float))
        for i in range(num_groups):
            x_i_fp4 = per_token_cast_to_fp4(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k)
            x_fp4[0][i], x_fp4[1][i] = x_i_fp4 if major.is_k_major() else (transpose_packed_fp4(x_i_fp4[0]), x_i_fp4[1])
        x = x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1])
    else:
        # The block cast has one scale per (gran_k x gran_k) tile, the token cast one per row
        x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn),
                 torch.empty((num_groups, ceil_div(mn, gran_k), ceil_div(k, gran_k)), device=device, dtype=torch.float) if use_block_cast_for_fp8 \
                 else torch.empty((num_groups, mn, ceil_div(k, gran_k)), device=device, dtype=torch.float))
        for i in range(num_groups):
            x_fp8[0][i], x_fp8[1][i] = per_block_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \
                else per_token_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k)
        # Same logical shape, but with the transposed memory order per group
        x = x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1])
    return x
|
||||
|
||||
|
||||
def generate_normal(m: int, n: int, k: int,
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB,
|
||||
accumulate: bool, out_dtype: torch.dtype,
|
||||
kernel_type: KernelType,
|
||||
use_ue8m0: bool = False, use_bf16: bool = False):
|
||||
use_ue8m0: bool = False, use_bf16: bool = False,
|
||||
quant_config: Optional[QuantConfig] = None):
|
||||
a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \
|
||||
@@ -161,25 +269,28 @@ def generate_normal(m: int, n: int, k: int,
|
||||
a = a if major_a.is_k_major() else a.T.contiguous().T
|
||||
b = b if major_b.is_k_major() else b.T.contiguous().T
|
||||
return a, b, c, d, ref_d
|
||||
|
||||
quant_config = QuantConfig() if quant_config is None else quant_config
|
||||
a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
|
||||
b = cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0,
|
||||
use_block_cast_for_fp8=not (kernel_type.is_1d1d() and accumulate))
|
||||
|
||||
a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
|
||||
b_fp8 = per_token_cast_to_fp8(b, use_ue8m0=use_ue8m0) if kernel_type.is_1d1d() and accumulate \
|
||||
else per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1])
|
||||
b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1])
|
||||
return a_fp8, b_fp8, c, d, ref_d
|
||||
return a, b, c, d, ref_d
|
||||
|
||||
|
||||
def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int,
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB,
|
||||
use_ue8m0: bool = False, use_bf16: bool = False):
|
||||
use_ue8m0: bool = False, use_bf16: bool = False,
|
||||
use_psum_layout: bool = False,
|
||||
quant_config: Optional[QuantConfig] = None):
|
||||
actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
|
||||
aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms]
|
||||
m = sum(aligned_ms)
|
||||
|
||||
a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
||||
m_indices = torch.empty(m, device='cuda', dtype=torch.int32)
|
||||
grouped_layout = torch.empty(num_groups, device='cuda', dtype=torch.int32) if use_psum_layout \
|
||||
else torch.empty(m, device='cuda', dtype=torch.int32)
|
||||
d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
@@ -187,48 +298,61 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n:
|
||||
for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)):
|
||||
actual_end = start + actual_m
|
||||
aligned_end = start + aligned_m
|
||||
m_indices[start:actual_end] = i
|
||||
m_indices[actual_end:aligned_end] = -1
|
||||
ref_d[start:aligned_end] = a[start:aligned_end] @ b[i].t()
|
||||
if use_psum_layout:
|
||||
grouped_layout[i] = actual_end
|
||||
else:
|
||||
grouped_layout[start: actual_end] = i
|
||||
grouped_layout[actual_end: aligned_end] = -1
|
||||
a[actual_end: aligned_end] = 0
|
||||
ref_d[start: aligned_end] = a[start: aligned_end] @ b[i].t()
|
||||
start = aligned_end
|
||||
ref_d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_d), ref_d)
|
||||
|
||||
if use_bf16:
|
||||
b = b if major_b.is_k_major() else b.mT.contiguous().mT
|
||||
return m, a, b, m_indices, d, ref_d
|
||||
return m, a, b, grouped_layout, d, ref_d
|
||||
|
||||
assert major_a.is_k_major()
|
||||
a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
|
||||
b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn),
|
||||
torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float))
|
||||
quant_config = QuantConfig() if quant_config is None else quant_config
|
||||
a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
|
||||
b = grouped_cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True)
|
||||
|
||||
return m, a, b, grouped_layout, d, ref_d
|
||||
|
||||
|
||||
def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor):
    """Copy the valid rows of a masked-layout tensor into a prefix-sum layout.

    Group `i` is written at rows `[start_i, psum_m[i])` of the output, where
    `start_0 = 0` and `start_i = align(psum_m[i - 1], 128)`.

    Args:
        x: Tensor of shape (num_groups, max_m, hidden).
        psum_m: Per-group end row in the output, one entry per group.

    Returns:
        A (num_groups * max_m, hidden) tensor; rows outside the written
        ranges are left uninitialized.
    """
    num_groups, max_m, _ = x.size()
    x_psum = torch.empty_like(x).view(num_groups * max_m, -1)
    last_psum_m = 0
    for i in range(num_groups):
        x_psum[last_psum_m: psum_m[i]] = x[i, :psum_m[i] - last_psum_m]
        # The next group starts at the next multiple of 128
        last_psum_m = align(psum_m[i], 128)
    return x_psum
|
||||
|
||||
|
||||
def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int,
                              use_ue8m0: bool = False, use_bf16: bool = False,
                              use_psum_layout: bool = False,
                              quant_config: Optional[QuantConfig] = None):
    """Generate random inputs and a BF16 reference for an M-grouped masked GEMM.

    Args:
        num_groups: Number of groups.
        max_m: Rows allocated per group.
        expected_m_per_group: Mean of the random valid row count per group
            (drawn within a factor of 0.7 to 1.3).
        n, k: GEMM dimensions.
        use_ue8m0: Forwarded to the quantization helpers.
        use_bf16: Return the BF16 operands without quantizing.
        use_psum_layout: Accepted for signature parity; not read in this
            function, `psum_m` is always returned.
        quant_config: Quantization settings; defaults to `QuantConfig()`.

    Returns:
        `(a, b, masked_m, psum_m, d, ref_d)`. `a` and `b` are BF16 tensors when
        `use_bf16` is set, otherwise `(data, scaling_factors)` tuples.
    """
    a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
    ref_d = torch.einsum('gmk,gnk->gmn', a, b)

    masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
    psum_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
    for j in range(num_groups):
        masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
        # Each group starts at the previous end rounded up to a multiple of 128
        psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], 128)) + masked_m[j]
    assert masked_m.amax().item() <= max_m

    if use_bf16:
        return a, b, masked_m, psum_m, d, ref_d

    quant_config = QuantConfig() if quant_config is None else quant_config
    a = grouped_cast_fp8_fp4_with_major(a, MajorTypeAB.KMajor, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
    b = grouped_cast_fp8_fp4_with_major(b, MajorTypeAB.KMajor, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True)

    return a, b, masked_m, psum_m, d, ref_d
|
||||
|
||||
|
||||
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int],
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import dataclasses
|
||||
import random
|
||||
import torch
|
||||
from typing import Tuple
|
||||
from typing import Tuple, List
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm.testing import (
|
||||
bench_kineto,
|
||||
calc_diff, count_bytes,
|
||||
ignore_env, get_arch_major
|
||||
ignore_env, get_arch_major,
|
||||
test_filter
|
||||
)
|
||||
from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8
|
||||
|
||||
@@ -154,7 +156,7 @@ def test_mqa_logits():
|
||||
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
|
||||
logits = logits.masked_fill(neginf_mask, 0)
|
||||
diff = calc_diff(logits, ref_logits)
|
||||
assert diff < 1e-3, f"{diff=}"
|
||||
assert diff < 1e-3, f'{diff=}'
|
||||
else:
|
||||
ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke, cost_only=True)
|
||||
|
||||
@@ -204,8 +206,6 @@ def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
|
||||
|
||||
|
||||
def test_paged_mqa_logits():
|
||||
# TODO: fully refactor with PyTest
|
||||
|
||||
print('Testing FP8 Paged MQA Logits:')
|
||||
max_model_len = 111 * 1000
|
||||
for is_context_lens_2d in (False, True):
|
||||
@@ -264,7 +264,7 @@ def test_paged_mqa_logits():
|
||||
else:
|
||||
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True),
|
||||
('fp8_paged_mqa_logits', 'clean_logits'))
|
||||
clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens)
|
||||
clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens)
|
||||
print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: '
|
||||
f'{tflops / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, '
|
||||
f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s', end='')
|
||||
@@ -273,6 +273,8 @@ def test_paged_mqa_logits():
|
||||
print()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
@@ -9,7 +9,7 @@ from deep_gemm.testing import (
|
||||
calc_diff, count_bytes
|
||||
)
|
||||
from generators import (
|
||||
get_arch_major,
|
||||
get_arch_major, layout_masked_to_psum, align,
|
||||
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
|
||||
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
|
||||
)
|
||||
@@ -18,11 +18,7 @@ from generators import (
|
||||
def test_gemm() -> None:
|
||||
print('Testing GEMM:')
|
||||
scores = []
|
||||
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16):
|
||||
# TODO: support accumulation for SM90 BF16 GEMM
|
||||
if get_arch_major() == 9 and accumulate:
|
||||
continue
|
||||
|
||||
for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
|
||||
@@ -56,29 +52,30 @@ def test_gemm() -> None:
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print('Testing m-grouped contiguous GEMM:')
|
||||
|
||||
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16):
|
||||
for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(torch.bfloat16):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
|
||||
for test_alias in (False, True):
|
||||
m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True)
|
||||
m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b,
|
||||
use_bf16=True, use_psum_layout=use_psum_layout)
|
||||
func_name = f"m_grouped_bf16_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous"
|
||||
if test_alias:
|
||||
assert major_a.is_k_major()
|
||||
b = b if major_b.is_k_major() else b.mT
|
||||
assert a[0].is_contiguous() and b[0].is_contiguous()
|
||||
getattr(deep_gemm, func_name)(a, b, d, m_indices)
|
||||
d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d)
|
||||
getattr(deep_gemm, func_name)(a, b, d, grouped_layout, use_psum_layout=use_psum_layout)
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}'
|
||||
m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True)
|
||||
m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b,
|
||||
use_bf16=True, use_psum_layout=use_psum_layout)
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, m_indices)
|
||||
deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout, use_psum_layout=use_psum_layout)
|
||||
|
||||
t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): '
|
||||
print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, psum={use_psum_layout}): '
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s')
|
||||
@@ -89,29 +86,52 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
print('Testing m-grouped masked GEMM:')
|
||||
|
||||
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
|
||||
for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.bfloat16):
|
||||
# Test correctness
|
||||
for i in range(10):
|
||||
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True)
|
||||
deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group)
|
||||
for _, _, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.bfloat16):
|
||||
num_tests = 8
|
||||
sum_t, max_t = 0, 0
|
||||
sum_ops, sum_bytes = 0, 0
|
||||
|
||||
for i in range(num_tests):
|
||||
a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k,
|
||||
use_bf16=True, use_psum_layout=use_psum_layout)
|
||||
if use_psum_layout:
|
||||
a_psum = layout_masked_to_psum(a, psum_m)
|
||||
d_psum = layout_masked_to_psum(d, psum_m)
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
if use_psum_layout:
|
||||
deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a_psum, b, d_psum, psum_m,
|
||||
use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group)
|
||||
else:
|
||||
deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group)
|
||||
|
||||
test_func()
|
||||
for j in range(num_groups):
|
||||
diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()])
|
||||
if masked_m[j].item() == 0:
|
||||
continue
|
||||
if use_psum_layout:
|
||||
d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]]
|
||||
else:
|
||||
d_slice = d[j, :masked_m[j].item()]
|
||||
diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()])
|
||||
assert diff < 1e-5, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
|
||||
|
||||
# Construct full cases
|
||||
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True)
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group)
|
||||
# Test performance with fixed shapes
|
||||
valid_m = masked_m.sum().item()
|
||||
t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True)
|
||||
|
||||
# Test performance with fixed shapes
|
||||
valid_m = masked_m.sum().item()
|
||||
t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): '
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s')
|
||||
sum_t += t
|
||||
max_t = max(max_t, t)
|
||||
sum_ops += 2 * valid_m * n * k
|
||||
sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)
|
||||
|
||||
print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, '
|
||||
f'psum={1 if use_psum_layout else 0}): '
|
||||
f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | '
|
||||
f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | '
|
||||
f'{sum_bytes / sum_t / 1e9:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
@@ -148,7 +168,7 @@ def test_k_grouped_gemm_contiguous() -> None:
|
||||
|
||||
def test_cublaslt_gemm() -> None:
|
||||
print('Testing cuBLASLt GEMM:')
|
||||
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16):
|
||||
for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
|
||||
@@ -159,7 +179,8 @@ def test_cublaslt_gemm() -> None:
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 6e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})'
|
||||
|
||||
t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), 'nvjet', suppress_kineto_output=True,)
|
||||
t_nvjet, t_gemv, t_gemm = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'gemv', 'gemm'), suppress_kineto_output=True)
|
||||
t = t_nvjet + t_gemv + t_gemm
|
||||
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'{t * 1e6:5.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
|
||||
@@ -80,7 +80,6 @@ def test_bhd_hdr_bhr():
|
||||
print()
|
||||
|
||||
|
||||
@test_filter(lambda: get_arch_major() >= 10)
|
||||
def test_fp8_bhr_hdr_bhd(use_ue8m0: bool = True):
|
||||
print('Testing FP8 "bhr, hdr -> bhd":')
|
||||
for h, r, d in [(8, 4096, 1024)]:
|
||||
|
||||
@@ -1,175 +0,0 @@
|
||||
import copy
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm.testing import (
|
||||
bench_kineto,
|
||||
calc_diff, count_bytes,
|
||||
ignore_env, get_arch_major
|
||||
)
|
||||
|
||||
from generators import (
|
||||
KernelType, get_ue8m0_usage,
|
||||
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
|
||||
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
|
||||
)
|
||||
|
||||
|
||||
@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9)
def test_gemm() -> None:
    """Validate and benchmark the dense FP8 GEMM entry points.

    For every configuration yielded by ``enumerate_normal`` the GEMM is run
    through ``fp8_gemm_nt`` and through the layout-specific alias, compared
    against ``ref_d``, then timed together with a cuBLASLt baseline. The
    geometric mean of the collected speedups is printed at the end.
    """
    print('Testing GEMM:')
    scores = []
    for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn):
        major_opt = 'N' if major_a.is_k_major() else 'T'
        major_opt += 'T' if major_b.is_k_major() else 'N'
        out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
        acc_opt = f'acc={int(accumulate)}'
        kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
        use_ue8m0 = get_ue8m0_usage(kernel_type)
        disable_ue8m0_cast = not use_ue8m0
        # An explicit recipe is only passed for accumulating 1D1D kernels; otherwise the library default is used
        recipe = (1, 1, 128) if kernel_type.is_1d1d() and accumulate else None

        # Correctness: once through the canonical NT name, once through the layout alias
        for test_alias in (False, True):
            a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
            func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}'
            if test_alias:
                # The alias APIs take the operands in their native (transposed) layout
                a = a if major_a.is_k_major() else (a[0].T, a[1].T)
                b = b if major_b.is_k_major() else (b[0].T, b[1].T)
                assert a[0].is_contiguous() and b[0].is_contiguous()
            getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe)
            diff = calc_diff(d, ref_d)
            assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, '
                                  f'{diff:.5f}, alias={test_alias}')

        # Performance: fresh inputs, DeepGEMM kernel vs. cuBLASLt (main kernel + split-K reduction)
        a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
        t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe),
                         'fp8_gemm', suppress_kineto_output=True)
        cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True)
        print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
              f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
              f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
              f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS')
        if cublas_t > 0:
            scores.append((cublas_t + split_k_t) / t)

    # Geometric mean of the speedups; guard against an empty list so the
    # summary cannot fail with ZeroDivisionError when no baseline was timed
    if scores:
        print(f"Average speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n")
    else:
        print('Average speedup over cuBLASLt: N/A (no cuBLASLt baseline was measured)\n')
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
    """Run correctness checks and a benchmark for the m-grouped contiguous FP8 GEMM."""
    print('Testing m-grouped contiguous GEMM:')

    for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn):
        # Two-letter layout tag: first letter for A, second for B
        major_opt = ('N' if major_a.is_k_major() else 'T') + ('T' if major_b.is_k_major() else 'N')
        kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
        use_ue8m0 = get_ue8m0_usage(kernel_type)
        disable_ue8m0_cast = not use_ue8m0

        # Correctness through the canonical name first, then the layout alias
        for test_alias in (False, True):
            m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0)
            layout_suffix = major_opt.lower() if test_alias else 'nt'
            grouped_gemm = f"m_grouped_fp8_gemm_{layout_suffix}_contiguous"
            if test_alias:
                assert major_a.is_k_major()
                if not major_b.is_k_major():
                    b = (b[0].mT, b[1].mT)
                assert a[0].is_contiguous() and b[0].is_contiguous()
            getattr(deep_gemm, grouped_gemm)(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)
            # Rows whose group index is -1 are zeroed before comparing with the reference
            skipped_rows = (m_indices == -1).unsqueeze(1)
            d = torch.where(skipped_rows, torch.zeros_like(d), d)
            diff = calc_diff(d, ref_d)
            assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}'

        # Benchmark on freshly generated inputs
        m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0)
        t = bench_kineto(
            lambda: deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast),
            'fp8_gemm', suppress_kineto_output=True)

        elapsed_us = t * 1e6
        tflops = 2 * m * n * k / t / 1e12
        gbps = count_bytes(a, b, d) / 1e9 / t
        print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}): '
              f'{elapsed_us:4.0f} us | '
              f'{tflops:4.0f} TFLOPS | '
              f'{gbps:4.0f} GB/s')
    print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_masked() -> None:
    """Run correctness checks and a benchmark for the m-grouped masked FP8 GEMM."""
    print('Testing m-grouped masked GEMM:')

    # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
    for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.float8_e4m3fn):
        kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
        use_ue8m0 = get_ue8m0_usage(kernel_type)
        disable_ue8m0_cast = not use_ue8m0

        # Correctness over ten independently generated inputs
        for _ in range(10):
            a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
            deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
            for j in range(num_groups):
                valid_rows = masked_m[j].item()
                if valid_rows == 0:
                    # Nothing to compare for an empty group
                    continue
                diff = calc_diff(d[j, :valid_rows], ref_d[j, :valid_rows])
                assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'

        # Benchmark on one more freshly generated case
        a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
        valid_m = masked_m.sum().item()
        t = bench_kineto(
            lambda: deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast),
            'fp8_gemm', suppress_kineto_output=True)

        # A and D traffic is scaled by the fraction of rows that are actually valid
        moved_bytes = count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)
        print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): '
              f'{t * 1e6:4.0f} us | '
              f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | '
              f'{moved_bytes / 1e9 / t:4.0f} GB/s')
    print()
|
||||
|
||||
|
||||
def test_k_grouped_gemm_contiguous() -> None:
    """Validate and benchmark the k-grouped contiguous FP8 GEMM.

    Each configuration is checked twice: once with the enumerated group sizes
    and once with one randomly chosen group emptied (only when there is more
    than one group). It is then benchmarked with the original group sizes.
    """
    print('Testing k-grouped contiguous GEMM:')

    # The NT entry point is used when the arch major is 9, the TN one otherwise
    k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \
        else deep_gemm.k_grouped_fp8_gemm_tn_contiguous
    for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.float8_e4m3fn):
        use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D)

        for test_empty_groups in (False, True):
            # Work on a copy so the enumerated `ks` stays intact for the benchmark below
            new_ks = copy.deepcopy(ks)
            if test_empty_groups and len(ks) > 1:
                new_ks[random.randint(0, num_groups - 1)] = 0
            k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0)
            new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda')
            k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c)

            diff = calc_diff(d, ref_d)
            # Report the group sizes that were actually used (`new_ks`), not the unmodified `ks`
            assert diff < 0.001, f'{m=}, {n=}, {k=}, {new_ks=}, {diff:.5f}'

        # Test performance
        k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0)
        ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')

        # noinspection PyShadowingNames
        def test_func():
            k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c)

        t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
        print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '
              f'{t * 1e6:4.0f} us | '
              f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
              f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s')
    print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Fixed seeds keep the generated test inputs reproducible
    torch.manual_seed(0)
    random.seed(0)

    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    # Run every suite in its original order
    for run_suite in (test_gemm,
                      test_m_grouped_gemm_contiguous,
                      test_m_grouped_gemm_masked,
                      test_k_grouped_gemm_contiguous):
        run_suite()
|
||||
207
tests/test_fp8_fp4.py
Normal file
207
tests/test_fp8_fp4.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import copy
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm.testing import (
|
||||
bench_kineto,
|
||||
calc_diff, count_bytes,
|
||||
ignore_env, get_arch_major
|
||||
)
|
||||
|
||||
from generators import (
|
||||
KernelType, get_ue8m0_usage, layout_masked_to_psum, align,
|
||||
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
|
||||
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
|
||||
)
|
||||
|
||||
|
||||
@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9)
def test_gemm() -> None:
    """Validate and benchmark the dense FP8/FP4 GEMM entry points.

    For every configuration yielded by ``enumerate_normal`` the GEMM is run
    through ``fp8_fp4_gemm_nt`` and through the layout-specific alias,
    compared against ``ref_d`` with the tolerance from
    ``quant_config.max_diff()``, and then timed. A cuBLASLt baseline is only
    timed when neither ``quant_config.is_fp4_a`` nor ``quant_config.is_fp4_b``
    is set; the geometric mean of those speedups is printed at the end.
    """
    print('Testing GEMM:')
    scores = []
    for kernel_type, quant_config, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn):
        major_opt = 'N' if major_a.is_k_major() else 'T'
        major_opt += 'T' if major_b.is_k_major() else 'N'
        out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
        acc_opt = f'acc={int(accumulate)}'
        kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
        use_ue8m0 = get_ue8m0_usage(kernel_type)
        disable_ue8m0_cast = not use_ue8m0
        recipe, recipe_a, recipe_b = quant_config.get_recipes(is_wgrad=(kernel_type.is_1d1d() and accumulate))

        # Correctness: once through the canonical NT name, once through the layout alias
        for test_alias in (False, True):
            a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config)
            func_name = f'fp8_fp4_gemm_{major_opt.lower() if test_alias else "nt"}'
            if test_alias:
                # The alias APIs take the operands in their native (transposed) layout
                a = a if major_a.is_k_major() else (a[0].T, a[1].T)
                b = b if major_b.is_k_major() else (b[0].T, b[1].T)
                assert a[0].is_contiguous() and b[0].is_contiguous()
            getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b)
            diff = calc_diff(d, ref_d)
            assert diff < quant_config.max_diff(), (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, '
                                                    f'{diff:.5f}, alias={test_alias}')

        # Performance on fresh inputs
        a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config)
        t = bench_kineto(lambda: deep_gemm.fp8_fp4_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b),
                         'fp8_gemm', suppress_kineto_output=True)
        if not quant_config.is_fp4_a and not quant_config.is_fp4_b:
            cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True)
        else:
            # No cuBLASLt baseline is timed when an operand is FP4; zeros keep it out of `scores`
            cublas_t, split_k_t = 0, 0
        print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
              f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
              f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
              f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS')
        if cublas_t > 0:
            scores.append((cublas_t + split_k_t) / t)

    # Geometric mean of the speedups; guard against an empty list so the
    # summary cannot fail with ZeroDivisionError when no baseline was timed
    if scores:
        print(f"Average FP8xFP8 GEMM speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n")
    else:
        print('Average FP8xFP8 GEMM speedup over cuBLASLt: N/A (no cuBLASLt baseline was measured)\n')
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
    """Run correctness checks and a benchmark for the m-grouped contiguous FP8/FP4 GEMM."""
    print('Testing m-grouped contiguous GEMM:')

    for kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn):
        # Two-letter layout tag: first letter for A, second for B
        major_opt = ('N' if major_a.is_k_major() else 'T') + ('T' if major_b.is_k_major() else 'N')
        kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
        use_ue8m0 = get_ue8m0_usage(kernel_type)
        disable_ue8m0_cast = not use_ue8m0
        recipe, recipe_a, recipe_b = quant_config.get_recipes()

        # Correctness through the canonical name first, then the layout alias
        for test_alias in (False, True):
            m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
                num_groups, expected_m_per_group, n, k, major_a, major_b,
                use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout,
                quant_config=quant_config)
            layout_suffix = major_opt.lower() if test_alias else 'nt'
            grouped_gemm = f"m_grouped_fp8_fp4_gemm_{layout_suffix}_contiguous"
            if test_alias:
                assert major_a.is_k_major()
                if not major_b.is_k_major():
                    b = (b[0].mT, b[1].mT)
                assert a[0].is_contiguous() and b[0].is_contiguous()
            getattr(deep_gemm, grouped_gemm)(
                a, b, d, grouped_layout,
                disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout,
                recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b)
            diff = calc_diff(d, ref_d)
            assert diff < quant_config.max_diff(), f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}'

        # Benchmark on freshly generated inputs
        m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(
            num_groups, expected_m_per_group, n, k, major_a, major_b,
            use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout,
            quant_config=quant_config)
        t = bench_kineto(
            lambda: deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
                a, b, d, grouped_layout,
                disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout,
                recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b),
            'fp8_gemm', suppress_kineto_output=True)

        elapsed_us = t * 1e6
        tflops = 2 * m * n * k / t / 1e12
        gbps = count_bytes(a, b, d) / 1e9 / t
        print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}, psum={use_psum_layout}): '
              f'{elapsed_us:4.0f} us | '
              f'{tflops:4.0f} TFLOPS | '
              f'{gbps:4.0f} GB/s')
    print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_masked() -> None:
    """Run correctness checks and an aggregated benchmark for the m-grouped masked FP8/FP4 GEMM.

    When ``use_psum_layout`` is set, the masked inputs are converted with
    ``layout_masked_to_psum`` and the contiguous entry point is used instead
    of the masked one.
    """
    print('Testing m-grouped masked GEMM:')

    # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
    for kernel_type, quant_config, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.float8_e4m3fn):
        kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
        use_ue8m0 = get_ue8m0_usage(kernel_type)
        disable_ue8m0_cast = not use_ue8m0
        recipe, recipe_a, recipe_b = quant_config.get_recipes()

        # Accumulators over `num_tests` independently generated cases
        num_tests = 8
        sum_t = max_t = 0
        sum_ops = sum_bytes = 0

        for _ in range(num_tests):
            a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(
                num_groups, max_m, expected_m_per_group, n, k,
                use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout,
                quant_config=quant_config)
            if use_psum_layout:
                a_psum = tuple(layout_masked_to_psum(part, psum_m) for part in (a[0], a[1]))
                d_psum = layout_masked_to_psum(d, psum_m)

            # noinspection PyShadowingNames
            def test_func():
                if not use_psum_layout:
                    deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(
                        a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast,
                        recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b)
                    return
                deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
                    a_psum, b, d_psum, psum_m, disable_ue8m0_cast=disable_ue8m0_cast,
                    use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group,
                    recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b)

            # Correctness: compare each non-empty group with the reference
            test_func()
            for j in range(num_groups):
                valid_rows = masked_m[j].item()
                if valid_rows == 0:
                    continue
                if use_psum_layout:
                    # Group j ends at psum_m[j]; later groups start at the previous end aligned to 128
                    if j == 0:
                        d_slice = d_psum[: psum_m[j]]
                    else:
                        d_slice = d_psum[align(psum_m[j - 1], 128): psum_m[j]]
                else:
                    d_slice = d[j, :valid_rows]
                diff = calc_diff(d_slice, ref_d[j, :valid_rows])
                assert diff < quant_config.max_diff(), f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'

            # Test performance with fixed shapes
            valid_m = masked_m.sum().item()
            t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)

            sum_t += t
            max_t = max(max_t, t)
            sum_ops += 2 * valid_m * n * k
            # A and D traffic is scaled by the fraction of rows that are actually valid
            sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)

        print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, '
              f'{kernel_opt}, psum={1 if use_psum_layout else 0}): '
              f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | '
              f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | '
              f'{sum_bytes / sum_t / 1e9:4.0f} GB/s')
    print()
|
||||
|
||||
|
||||
@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9)
def test_k_grouped_gemm_contiguous() -> None:
    """Validate and benchmark the k-grouped contiguous FP8 GEMM.

    Each configuration is checked twice: once with the enumerated group sizes
    and once with one randomly chosen group emptied (only when there is more
    than one group). It is then benchmarked with the original group sizes.
    """
    print('Testing k-grouped contiguous GEMM:')

    # The NT entry point is used when the arch major is 9, the TN one otherwise
    k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \
        else deep_gemm.k_grouped_fp8_gemm_tn_contiguous
    for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.float8_e4m3fn):
        use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D)

        for test_empty_groups in (False, True):
            # Work on a copy so the enumerated `ks` stays intact for the benchmark below
            new_ks = copy.deepcopy(ks)
            if test_empty_groups and len(ks) > 1:
                new_ks[random.randint(0, num_groups - 1)] = 0
            k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0)
            new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda')
            k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c)

            diff = calc_diff(d, ref_d)
            # Report the group sizes that were actually used (`new_ks`), not the unmodified `ks`
            assert diff < 0.001, f'{m=}, {n=}, {k=}, {new_ks=}, {diff:.5f}'

        # Test performance
        k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0)
        ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')

        # noinspection PyShadowingNames
        def test_func():
            k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c)

        t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
        print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '
              f'{t * 1e6:4.0f} us | '
              f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
              f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s')
    print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Fixed seeds keep the generated test inputs reproducible
    torch.manual_seed(0)
    random.seed(0)

    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    # Run every suite in its original order
    for run_suite in (test_gemm,
                      test_m_grouped_gemm_contiguous,
                      test_m_grouped_gemm_masked,
                      test_k_grouped_gemm_contiguous):
        run_suite()
|
||||
57
tests/test_hyperconnection.py
Normal file
57
tests/test_hyperconnection.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
import random
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm.testing import (
|
||||
test_filter,
|
||||
bench_kineto,
|
||||
calc_diff, count_bytes
|
||||
)
|
||||
from deep_gemm.utils import align
|
||||
from generators import get_arch_major
|
||||
|
||||
|
||||
@test_filter(lambda: get_arch_major() >= 9)
def test_hc_prenorm_gemm() -> None:
    """Check ``tf32_hc_prenorm_gemm`` against a PyTorch reference and benchmark it.

    The kernel fills ``d`` and ``s``; with ``num_splits`` set, both carry a
    leading split dimension that is summed before the comparison. The
    references are ``a.float() @ b.T`` and the row-wise sum of squares of ``a``.
    """
    # Needs TF32 precision for PyTorch GEMMs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    print('Testing hyperconnection prenorm GEMM:')
    shapes_nk = [(24, 28672), (24, 7680), (24, 7168)]
    for m in (13, 137, 4096, 8192):
        for n, k in shapes_nk:
            for num_splits in [None, 16]:
                split = num_splits is not None
                a = torch.randn((m, k), dtype=torch.bfloat16, device='cuda')
                b = torch.randn((n, k), dtype=torch.float, device='cuda')
                # Outputs gain a leading split dimension when split-K is requested
                if split:
                    d = torch.empty((num_splits, m, n), dtype=torch.float, device='cuda')
                    s = torch.empty((num_splits, m), dtype=torch.float, device='cuda')
                else:
                    d = torch.empty((m, n), dtype=torch.float, device='cuda')
                    s = torch.empty((m, ), dtype=torch.float, device='cuda')
                deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits)
                final_d = d.sum(0) if split else d
                final_s = s.sum(0) if split else s

                ref_d = a.float() @ b.T
                ref_s = a.float().square().sum(-1)

                diff_d = calc_diff(final_d, ref_d)
                diff_s = calc_diff(final_s, ref_s)
                diff = max(diff_d, diff_s)
                assert diff < 1e-8, f'{m=}, {n=}, {k=}, {diff:.10f}'

                t = bench_kineto(lambda: deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits), 'tf32_hc_prenorm_gemm', suppress_kineto_output=True)
                elapsed_us = t * 1e6
                tflops = 2 * m * n * k / t / 1e12
                gbps = count_bytes(a, b, d, s) / 1e9 / t
                print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, num_splits={(num_splits or 0):2}): '
                      f'{elapsed_us:4.0f} us | '
                      f'{tflops:4.0f} TFLOPS | '
                      f'{gbps:4.0f} GB/s')
    print()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Fixed seeds keep the generated test inputs reproducible
    SEED = 0
    torch.manual_seed(SEED)
    random.seed(SEED)

    library_path = deep_gemm.__path__
    print('Library path:')
    print(f' > {library_path}\n')

    test_hc_prenorm_gemm()
|
||||
@@ -13,7 +13,7 @@ from generators import (
|
||||
|
||||
def test_m_grouped_gemm_contiguous_tl() -> None:
|
||||
print('Testing m-grouped contiguous Triton GEMM:')
|
||||
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16):
|
||||
for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, _ in enumerate_m_grouped_contiguous(torch.bfloat16):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user