Files
DeepGEMM/csrc/jit_kernels/heuristics/common.hpp
Chenggang Zhao 7f2a703ed5 [Public release 26/04] Introducing Mega MoE, FP4 Indexer and other features/fixes (#304)
* Merge with private repo

* Update README

* Update README

* Update README

* Add PyTorch requirements

* Fix sync scopes for MQA logits (#256)

* Update README
2026-04-17 09:45:14 +08:00

55 lines
1.8 KiB
C++

#pragma once
#include <unordered_set>
#include <deep_gemm/common/types.cuh>
#include "config.hpp"
#include "runtime.hpp"
#include "../../utils/layout.hpp"
#include "../../utils/system.hpp"
namespace deep_gemm {
template <typename ArchSpec>
static GemmConfig get_best_config(const GemmDesc& desc) {
desc.check_validity();
// Choose the best layout
const auto layout_candidates = ArchSpec::get_layout_candidates(desc);
DG_HOST_ASSERT(not layout_candidates.empty());
auto layout = layout_candidates[0];
auto layout_info = ArchSpec::get_layout_info(desc, layout);
for (int i = 1; i < static_cast<int>(layout_candidates.size()); ++ i) {
const auto candidate_info = ArchSpec::get_layout_info(desc, layout_candidates[i]);
if (ArchSpec::compare(candidate_info, layout_info))
layout = layout_candidates[i], layout_info = candidate_info;
}
// Infer other configs
const auto storage_config = ArchSpec::get_storage_config(desc, layout);
const auto pipeline_config = ArchSpec::get_pipeline_config(desc, layout, storage_config);
const auto launch_config = ArchSpec::get_launch_config(desc, layout);
const auto gemm_config = GemmConfig {
.layout = layout,
.storage_config = storage_config,
.pipeline_config = pipeline_config,
.launch_config = launch_config
};
// Print configs for the first time
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
std::stringstream ss;
ss << desc;
const auto key = ss.str();
static std::unordered_set<std::string> printed;
if (printed.count(key) == 0) {
std::cout << desc << ": " << gemm_config << ", " << layout_info << std::endl;
printed.insert(key);
}
}
return gemm_config;
}
} // namespace deep_gemm