Add various optimizations and Mega MoE benchmarks (#316)

* Merge with private repo

* Add Mega MoE Benchmark

* Minor fix

* Update

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
Zhean Xu
2026-04-24 18:41:37 +08:00
committed by GitHub
parent 7f2a703ed5
commit 891d57b4db
21 changed files with 1276 additions and 372 deletions

View File

@@ -9,15 +9,15 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
## News
- 2026.04.16: Mega MoE, FP8xFP4 GEMM, FP4 Indexer, PDL, faster JIT compilation and more.
- Performance comparison will be posted later.
- Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details.
- Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details.
- For Mega MoE benchmarks, refer to [#316](https://github.com/deepseek-ai/DeepGEMM/pull/316).
- 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer for DeepSeek v3.2.
- Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details.
- Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details.
- 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module.
- NVRTC and post-compilation SASS optimization are all disabled.
- NVRTC will be supported later.
- As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported.
- Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details.
- NVRTC and post-compilation SASS optimization are all disabled.
- NVRTC will be supported later.
- As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported.
- Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details.
- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases).
- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.
@@ -30,9 +30,9 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- Python 3.8 or higher
- Compilers with C++20 support
- CUDA Toolkit:
- CUDA 12.3 or higher for SM90
- **We highly recommend 12.9 or higher for the best performance**
- CUDA 12.9 or higher for SM100
- CUDA 12.3 or higher for SM90
- **We highly recommend 12.9 or higher for the best performance**
- CUDA 12.9 or higher for SM100
- PyTorch 2.1 or higher
- CUTLASS 4.0 or higher (could be cloned by Git submodule)
- `{fmt}` library (could be cloned by Git submodule)
@@ -159,30 +159,30 @@ The library provides some utility functions besides the above kernels:
The library also provides some environment variables, which may be useful:
- General
- `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default
- `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
- `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default
- `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
- JIT cache
- `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default
- `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default
- Compiler selection
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, may have lower performance for some cases), `0` by default
- `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME`
- `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, may have lower performance for some cases), `0` by default
- `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME`
- `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default
- Compiler output
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default
- `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default
- `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default
- `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default
- `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default
- Debug and profiling
- `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default
- `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default
- `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default
- `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default
- `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero symmetric buffer before each Mega MoE call for debugging, `0` by default
- `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default
- `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default
- `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default
- `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default
- `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default
- `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero symmetric buffer before each Mega MoE call for debugging, `0` by default
- `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default
- Build options
- `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip CUDA extension build during installation, `0` by default
- `DG_FORCE_BUILD`: `0` or `1`, force local build instead of downloading pre-built wheels, `0` by default
- `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default
- `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip CUDA extension build during installation, `0` by default
- `DG_FORCE_BUILD`: `0` or `1`, force local build instead of downloading pre-built wheels, `0` by default
- `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default
For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.

View File

@@ -190,12 +190,13 @@ static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::opt
return logits;
}
static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms) {
static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms, const std::optional<torch::Tensor>& indices) {
// NOTES: Only 2D context lens is supported for now
DG_HOST_ASSERT(context_lens.dim() == 2);
const bool is_context_lens_2d = true;
const int batch_size = context_lens.size(0);
const int next_n = context_lens.size(1);
const bool is_varlen = indices.has_value();
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
DG_HOST_ASSERT(context_lens.is_contiguous());
@@ -204,9 +205,16 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_
// Dispatch implementation
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9 or arch_major == 10) {
if (is_varlen) {
const auto& indices_tensor = indices.value();
DG_HOST_ASSERT(arch_major == 10 and next_n == 1 and (block_kv == 64 or block_kv == 32));
DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
DG_HOST_ASSERT(indices_tensor.is_contiguous());
DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d, true, indices_tensor.data_ptr<int>());
} else if (arch_major == 9 or arch_major == 10) {
DG_HOST_ASSERT(block_kv == 64 or (arch_major == 10 and block_kv == 32));
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d);
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d, false, nullptr);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
@@ -222,7 +230,8 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
const torch::Tensor& schedule_meta,
const int& max_context_len,
const bool& clean_logits,
const at::ScalarType& logits_dtype) {
const at::ScalarType& logits_dtype,
const std::optional<torch::Tensor>& indices) {
const auto [q_fp, q_sf] = q;
const bool is_fp4 = q_sf.has_value();
@@ -321,6 +330,17 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
DG_HOST_ASSERT(block_table.stride(1) == 1);
DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt);
// Check indices
const bool is_varlen = indices.has_value();
const auto arch_major = device_runtime->get_arch_major();
const auto indices_tensor = indices.value_or(torch::Tensor());
if (is_varlen) {
DG_HOST_ASSERT(arch_major == 10 and next_n == 1);
DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
DG_HOST_ASSERT(indices_tensor.is_contiguous());
DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
}
// Check schedule metadata
auto [_schedule_meta_size, _meta_info_size] = get_shape<2>(schedule_meta);
DG_HOST_ASSERT(_schedule_meta_size == num_sms + 1 and _meta_info_size == 2);
@@ -344,15 +364,14 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
DG_HOST_ASSERT(logits_dtype == torch::kFloat32 or logits_dtype == torch::kBFloat16);
// Dispatch implementation
const auto arch_major = device_runtime->get_arch_major();
if (is_fp4 and arch_major == 10) {
sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, schedule_meta,
sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
aligned_max_context_len, block_table_stride, num_sms, split_kv);
is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) {
smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, schedule_meta,
smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
aligned_max_context_len, block_table_stride, num_sms, split_kv);
is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
@@ -386,10 +405,11 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
const torch::Tensor& block_table,
const torch::Tensor& schedule_meta,
const int& max_context_len,
const bool& clean_logits) {
const bool& clean_logits,
const std::optional<torch::Tensor>& indices) {
return fp8_fp4_paged_mqa_logits(std::make_tuple(q, std::nullopt), fused_kv_cache, weights,
context_lens, block_table, schedule_meta,
max_context_len, clean_logits, torch::kFloat);
max_context_len, clean_logits, torch::kFloat, indices);
}
#endif
@@ -407,13 +427,15 @@ static void register_apis(pybind11::module_& m) {
py::arg("max_seqlen_k") = 0,
py::arg("logits_dtype") = torch::kFloat32);
m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata,
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"));
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"),
py::arg("indices") = std::nullopt);
m.def("fp8_fp4_paged_mqa_logits", &fp8_fp4_paged_mqa_logits,
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
py::arg("max_context_len"),
py::arg("clean_logits") = false,
py::arg("logits_dtype") = torch::kFloat32);
py::arg("logits_dtype") = torch::kFloat32,
py::arg("indices") = std::nullopt);
// Legacy API
m.def("fp8_mqa_logits", &fp8_mqa_logits,
py::arg("q"), py::arg("kv"), py::arg("weights"),
@@ -423,7 +445,8 @@ static void register_apis(pybind11::module_& m) {
m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits,
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
py::arg("max_context_len"), py::arg("clean_logits") = false);
py::arg("max_context_len"), py::arg("clean_logits") = false,
py::arg("indices") = std::nullopt);
#endif
}

View File

@@ -11,6 +11,10 @@
namespace deep_gemm::mega {
static int get_token_alignment_for_mega_moe() {
return layout::kLCMCandidateBlockM;
}
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
get_symm_buffer_size_for_mega_moe(
const int& num_ranks, const int& num_experts,
@@ -20,8 +24,7 @@ get_symm_buffer_size_for_mega_moe(
DG_HOST_ASSERT(num_experts % num_ranks == 0);
// Workspace bytes
const auto block_m = get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk, block_m);
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
// Layouts
const auto fp8_token_layout = layout::Data(hidden);
@@ -49,14 +52,20 @@ get_symm_buffer_size_for_mega_moe(
// Buffer configs
const auto num_max_pool_tokens = static_cast<int>(workspace.num_max_pool_tokens);
const auto num_padded_sf_pool_tokens = layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m);
int num_max_padded_sf_pool_tokens = 0;
for (int block_m: layout::kCandidateBlockM) {
num_max_padded_sf_pool_tokens = std::max(
num_max_padded_sf_pool_tokens,
layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)
);
}
// L1 input buffer
const auto l1_token_buffer = layout::Buffer(
fp8_token_layout, 1, num_max_pool_tokens,
input_topk_weights_buffer.get_end_ptr());
const auto l1_sf_buffer = layout::Buffer(
fp8_sf_layout, 1, num_padded_sf_pool_tokens,
fp8_sf_layout, 1, num_max_padded_sf_pool_tokens,
l1_token_buffer.get_end_ptr());
const auto l1_topk_weights_buffer = layout::Buffer(
l1_topk_weights_layout, 1, num_max_pool_tokens,
@@ -67,7 +76,7 @@ get_symm_buffer_size_for_mega_moe(
fp8_intermediate_token_layout, 1, num_max_pool_tokens,
l1_topk_weights_buffer.get_end_ptr());
const auto l2_sf_buffer = layout::Buffer(
fp8_intermediate_sf_layout, 1, num_padded_sf_pool_tokens,
fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
l2_token_buffer.get_end_ptr());
// Combine input buffer: BF16 tokens for cross-rank combine
@@ -77,7 +86,7 @@ get_symm_buffer_size_for_mega_moe(
// Check SF buffer requirements
DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
DG_HOST_ASSERT(num_padded_sf_pool_tokens % 4 == 0);
DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);
// Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
// NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
@@ -104,8 +113,8 @@ get_symm_buffer_size_for_mega_moe(
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
auto l1_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
{num_padded_sf_pool_tokens, hidden / 128},
{1, num_padded_sf_pool_tokens},
{num_max_padded_sf_pool_tokens, hidden / 128},
{1, num_max_padded_sf_pool_tokens},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
auto l2_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
@@ -113,8 +122,8 @@ get_symm_buffer_size_for_mega_moe(
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
auto l2_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
{num_padded_sf_pool_tokens, intermediate_hidden / 128},
{1, num_padded_sf_pool_tokens},
{num_max_padded_sf_pool_tokens, intermediate_hidden / 128},
{1, num_max_padded_sf_pool_tokens},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
};
@@ -123,8 +132,9 @@ get_symm_buffer_size_for_mega_moe(
static void fp8_fp4_mega_moe(
const torch::Tensor& y,
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_,
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_,
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple,
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_tuple,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const torch::Tensor& sym_buffer,
const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
const int& num_max_tokens_per_rank,
@@ -132,9 +142,10 @@ static void fp8_fp4_mega_moe(
const std::tuple<int, int, int>& recipe,
const std::string& activation,
const std::optional<float>& activation_clamp_opt,
const bool& fast_math) {
const auto [l1_weights, l1_weights_sf] = l1_weights_;
const auto [l2_weights, l2_weights_sf] = l2_weights_;
const bool& fast_math
) {
const auto [l1_weights, l1_weights_sf] = l1_weights_tuple;
const auto [l2_weights, l2_weights_sf] = l2_weights_tuple;
// Config checks
const auto num_tokens = static_cast<int>(y.size(0));
@@ -161,13 +172,20 @@ static void fp8_fp4_mega_moe(
DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());
// Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment
// Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment
constexpr int kGranMN = 1, kGranK = 32;
check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
num_experts_per_rank, true, false, torch::kInt);
check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
num_experts_per_rank, true, false, torch::kInt);
// Check stats counter
if (cumulative_local_expert_recv_stats.has_value()) {
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank);
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous());
}
// Check buffer bytes
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
const auto num_experts_ = num_experts_per_rank * num_ranks;
@@ -175,7 +193,7 @@ static void fp8_fp4_mega_moe(
num_ranks, num_experts,
num_max_tokens_per_rank, num_topk,
hidden, intermediate_hidden,
true, "swiglu");
true, activation);
DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes));
DG_HOST_ASSERT(num_experts == num_experts_);
@@ -189,6 +207,7 @@ static void fp8_fp4_mega_moe(
l2_acts, l2_acts_sf,
l1_weights, l2_weights,
l1_weights_sf, l2_weights_sf,
cumulative_local_expert_recv_stats,
sym_buffer_ptrs,
rank_idx, num_max_tokens_per_rank,
num_experts_per_rank,
@@ -207,7 +226,7 @@ static void fp8_fp4_mega_moe(
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
m.def("get_block_m_for_mega_moe", &get_block_m_for_mega_moe);
m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe);
m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe);
m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe);
#endif

View File

@@ -55,38 +55,68 @@ struct MegaMoEConfig {
}
};
static int get_block_m_for_mega_moe(const int& num_ranks, const int& num_experts,
const int& num_max_tokens_per_rank, const int& num_topk) {
// TODO: compute based on configs
return 192;
static std::tuple<int, int, int, int> get_block_config_for_mega_moe(
const int& num_ranks, const int& num_experts,
const int& num_max_tokens_per_rank, const int& num_topk,
const int& num_tokens) {
const auto& [cluster_size, block_m, store_block_m, num_epilogue_warpgroups] = [&]() -> std::tuple<int, int, int, int> {
float num_expected_tokens_per_expert = static_cast<float>(num_tokens) * num_ranks * num_topk / num_experts;
if (num_expected_tokens_per_expert <= 8.5) {
// Really small token-per-expert (e.g. RL long-tail rollout), use the smallest block_m
return {2, 16, 8, 2};
} else if (num_expected_tokens_per_expert <= 16.5) {
// Small batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 128
return {2, 32, 16, 2};
} else if (num_expected_tokens_per_expert <= 32.5) {
// Medium batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 256
return {2, 64, 32, 1};
} else if (num_expected_tokens_per_expert <= 64.5) {
// Large batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 512
return {2, 96, 16, 2};
} else if (num_expected_tokens_per_expert <= 96.5) {
// Medium batch size, Medium EP, decoding, e.g. 6/384 experts, EP16, bsz 256, or EP32, bsz128
return {2, 128, 32, 2};
} else {
// Prefill, or large EP decoding
return {2, 192, 32, 2};
}
}();
// Check whether our `block_m` lies in `kCandidateBlockM`
DG_HOST_ASSERT(std::any_of(
layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs,
[=](const auto& candidate) { return candidate == block_m; })
);
// Return configs
return {cluster_size, block_m, store_block_m, num_epilogue_warpgroups * 128};
}
static int get_num_experts_per_wave_for_mega_moe(
const int& num_experts_per_rank, const int& num_tokens, const int& num_topk,
const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) {
float expected_tokens_per_expert = static_cast<float>(num_tokens) * num_topk / num_experts_per_rank;
if (expected_tokens_per_expert < 1) {
// Most experts don't have tokens, calculate all experts at once
return num_experts_per_rank;
}
// Reduce per-expert block count by this factor since uneven routing leaves some experts with fewer tokens
constexpr int kImbalanceFactor = 2;
// TODO: support num_experts_per_rank > 32
// Find the largest divisor of num_experts_per_rank that fits in 32 as the upper bound
int max_num_experts_per_wave = std::min(32, num_experts_per_rank);
while (max_num_experts_per_wave > 1 and num_experts_per_rank % max_num_experts_per_wave != 0)
-- max_num_experts_per_wave;
// Count L1 blocks per expert assuming tokens are evenly spread across experts
const int expected_tokens_per_expert =
num_tokens * num_topk / num_experts_per_rank + 1;
const int num_m_blocks = ceil_div(expected_tokens_per_expert, block_m);
const int num_n_blocks = intermediate_hidden / block_n;
const int num_m_blocks = ceil_div(static_cast<int>(std::ceil(expected_tokens_per_expert)), block_m);
const int num_n_blocks = (2 * intermediate_hidden) / block_n;
const int num_l1_blocks_per_expert = num_m_blocks * num_n_blocks;
// Pick the smallest value whose total blocks (after imbalance reduction) can keep all SMs busy
int num_experts_per_wave = num_l1_blocks_per_expert > 0
? ceil_div(kImbalanceFactor * num_sms, num_l1_blocks_per_expert) : 1;
num_experts_per_wave = std::min(num_experts_per_wave, max_num_experts_per_wave);
num_experts_per_wave = std::min(num_experts_per_wave, num_experts_per_rank);
// Round up to the nearest divisor of num_experts_per_rank so every wave processes the same count
while (num_experts_per_wave < max_num_experts_per_wave and num_experts_per_rank % num_experts_per_wave != 0)
while (num_experts_per_wave < num_experts_per_rank and num_experts_per_rank % num_experts_per_wave != 0)
++ num_experts_per_wave;
return num_experts_per_wave;
@@ -148,18 +178,18 @@ static std::pair<int, int> get_pipeline_config_for_mega_moe(
static MegaMoEConfig get_mega_moe_config(
const int& num_ranks, const int& num_experts, const int& num_experts_per_rank,
const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk,
const int& hidden, const int& intermediate_hidden) {
// Block tiling
const int block_m = get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
const int& hidden, const int& intermediate_hidden,
const int& num_padded_sf_pool_tokens) {
// Block config
const auto [cluster_size, block_m, store_block_m, num_epilogue_threads] =
get_block_config_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens);
const int block_n = 128;
const int block_k = 128;
const int load_block_m = block_m / 2;
const int load_block_n = block_n;
const int store_block_m = 32;
const auto [sf_block_m, sf_block_n] = SM100ArchSpec::get_sf_uttcp_aligned_block_sizes(block_m, block_n, MmaKind::MXFP8FP4);
const int num_max_pool_tokens = layout::get_num_max_pool_tokens(
num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank, block_m);
const int num_padded_sf_pool_tokens = layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m);
num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
// NOTES: FP8 activations and FP4 weights (unpacked to 8-bit in smem) both use 128B swizzle
const int swizzle_acts_mode = 128;
const int swizzle_weights_mode = 128;
@@ -173,7 +203,6 @@ static MegaMoEConfig get_mega_moe_config(
// Thread layout
const int num_dispatch_threads = 128;
const int num_non_epilogue_threads = 128;
const int num_epilogue_threads = 256;
// Pipeline
const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe(

View File

@@ -29,6 +29,7 @@ public:
// Runtime arguments
void* y;
int* cumulative_local_expert_recv_stats;
int num_tokens;
layout::SymBuffer<> sym_buffer_ptrs;
@@ -91,6 +92,7 @@ static void __instantiate_kernel() {{
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.y,
args.cumulative_local_expert_recv_stats,
args.num_tokens,
args.sym_buffer_ptrs,
args.tensor_map_l1_acts,
@@ -112,6 +114,7 @@ static void sm100_fp8_fp4_mega_moe(
const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf,
const torch::Tensor& l1_weights, const torch::Tensor& l2_weights,
const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf,
const std::optional<torch::Tensor> cumulative_local_expert_recv_stats,
const std::vector<int64_t>& sym_buffer_ptrs,
const int& rank_idx, const int& num_max_tokens_per_rank,
const int& num_experts_per_rank,
@@ -122,11 +125,12 @@ static void sm100_fp8_fp4_mega_moe(
) {
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
const auto num_experts = num_experts_per_rank * num_ranks;
const auto num_padded_sf_pool_tokens = static_cast<int>(l1_acts_sf.size(0));
// Heuristics
const auto config = get_mega_moe_config(
num_ranks, num_experts, num_experts_per_rank,
num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden);
num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden, num_padded_sf_pool_tokens);
// Make tensormap
constexpr int kGranK = 32;
@@ -175,6 +179,11 @@ static void sm100_fp8_fp4_mega_moe(
config.block_n, kGranK,
num_experts_per_rank, 0);
// Stats can be optional
int* cumulative_local_expert_recv_stats_ptr = nullptr;
if (cumulative_local_expert_recv_stats.has_value())
cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr<int>();
// Launch
const auto num_sms = device_runtime->get_num_sms();
const SM100FP8FP4MegaMoERuntime::Args args = {
@@ -186,6 +195,7 @@ static void sm100_fp8_fp4_mega_moe(
.fast_math = fast_math,
.config = config,
.y = y.data_ptr(),
.cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr,
.num_tokens = num_tokens,
.sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx),
.tensor_map_l1_acts = tensor_map_l1_acts,

View File

@@ -14,11 +14,13 @@ public:
int aligned_batch_size;
int split_kv;
int num_sms;
bool is_varlen;
int batch_size;
int next_n;
bool is_context_lens_2d;
int* context_lens;
int* indices;
int* schedule_metadata;
LaunchArgs launch_args;
@@ -32,10 +34,10 @@ using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sched::smxx_paged_mqa_logits_metadata<
{}, {}, {}
{}, {}, {}, {}
>);
}};
)", args.aligned_batch_size, args.split_kv, args.num_sms);
)", args.aligned_batch_size, args.split_kv, args.num_sms, args.is_varlen ? "true" : "false");
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -44,6 +46,7 @@ static void __instantiate_kernel() {{
args.next_n,
args.is_context_lens_2d,
args.context_lens,
args.indices,
args.schedule_metadata
));
}
@@ -53,14 +56,15 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
const torch::Tensor& schedule_metadata,
const int& batch_size, const int& next_n,
const int& block_kv, const int& num_sms,
const bool& is_context_lens_2d) {
const bool& is_context_lens_2d,
const bool& is_varlen, const int* indices_ptr) {
constexpr int split_kv = 256;
constexpr int num_threads = 32;
const int aligned_batch_size = align(batch_size, 32);
DG_HOST_ASSERT(split_kv % block_kv == 0);
// Calculate shared memory size
const int smem_size = aligned_batch_size * static_cast<int>(sizeof(int));
// Shared memory: prefix_sum[kAlignedBatchSize] + varlen_atom_token_start/context_len[kAlignedBatchSize] + varlen_num_atoms
const int smem_size = (3 * aligned_batch_size + 1) * static_cast<int>(sizeof(int));
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
@@ -69,10 +73,12 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
.aligned_batch_size = aligned_batch_size,
.split_kv = split_kv,
.num_sms = num_sms,
.is_varlen = is_varlen,
.batch_size = batch_size,
.next_n = next_n,
.is_context_lens_2d = is_context_lens_2d,
.context_lens = context_lens.data_ptr<int>(),
.indices = const_cast<int*>(indices_ptr),
.schedule_metadata = schedule_metadata.data_ptr<int>(),
.launch_args = LaunchArgs(1, num_threads, smem_size)
};
@@ -90,6 +96,7 @@ public:
int head_dim;
int block_kv;
bool is_context_lens_2d;
bool is_varlen;
int block_table_stride;
int logits_stride;
@@ -100,6 +107,7 @@ public:
int* context_lens;
void* logits;
int* block_table;
int* indices;
int* schedule_meta;
CUtensorMap tensor_map_q;
@@ -129,7 +137,7 @@ static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_paged_mqa_logits<
{}, {},
{}, {},
{},
{}, {},
{}, {},
{},
{}, {},
@@ -139,7 +147,7 @@ static void __instantiate_kernel() {{
)", arch, arch,
args.next_n, args.num_heads,
args.head_dim, args.block_kv,
args.is_context_lens_2d,
args.is_context_lens_2d, args.is_varlen ? "true" : "false",
args.num_q_stages, args.num_kv_stages,
args.split_kv,
args.num_specialized_threads, args.num_math_threads,
@@ -151,7 +159,7 @@ static void __instantiate_kernel() {{
args.batch_size,
args.logits_stride, args.block_table_stride,
args.context_lens, args.logits,
args.block_table, args.schedule_meta,
args.block_table, args.indices, args.schedule_meta,
args.tensor_map_q, args.tensor_map_kv,
args.tensor_map_kv_scales, args.tensor_map_weights
));
@@ -165,12 +173,14 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
const torch::Tensor& context_lens,
const torch::Tensor& logits,
const torch::Tensor& block_table,
const torch::Tensor& indices,
const torch::Tensor& schedule_meta,
const at::ScalarType& logits_dtype,
const int& batch_size, const int& next_n,
const int& num_heads, const int& head_dim,
const int& num_kv_blocks, const int& block_kv,
const bool& is_context_lens_2d,
const bool& is_varlen,
const int& logits_stride,
const int& block_table_stride,
const int& num_sms,
@@ -183,7 +193,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0);
// Construct TMAs
const int next_n_atom = (next_n % 2 == 0) ? 2 : 1;
const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1;
const auto tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads,
head_dim, next_n_atom * num_heads,
static_cast<int>(q.stride(2)),
@@ -245,6 +255,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
.head_dim = head_dim,
.block_kv = block_kv,
.is_context_lens_2d = is_context_lens_2d,
.is_varlen = is_varlen,
.block_table_stride = block_table_stride,
.logits_stride = logits_stride,
.num_q_stages = num_q_stages,
@@ -253,6 +264,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
.context_lens = context_lens.data_ptr<int>(),
.logits = logits.data_ptr(),
.block_table = block_table.data_ptr<int>(),
.indices = is_varlen ? indices.data_ptr<int>() : nullptr,
.schedule_meta = schedule_meta.data_ptr<int>(),
.tensor_map_q = tensor_map_q,
.tensor_map_kv = tensor_map_kv,
@@ -279,6 +291,7 @@ public:
int head_dim;
int block_kv;
bool is_context_lens_2d;
bool is_varlen;
int block_table_stride;
int logits_stride;
@@ -289,6 +302,7 @@ public:
int* context_lens;
void* logits;
int* block_table;
int* indices;
int* schedule_meta;
CUtensorMap tensor_map_q;
@@ -314,7 +328,7 @@ static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp4_paged_mqa_logits<
{}, {},
{}, {},
{},
{}, {},
{}, {},
{},
{}, {},
@@ -323,7 +337,7 @@ static void __instantiate_kernel() {{
}};
)", args.next_n, args.num_heads,
args.head_dim, args.block_kv,
args.is_context_lens_2d,
args.is_context_lens_2d, args.is_varlen ? "true" : "false",
args.num_q_stages, args.num_kv_stages,
args.split_kv,
args.num_specialized_threads, args.num_math_threads,
@@ -335,7 +349,7 @@ static void __instantiate_kernel() {{
args.batch_size,
args.logits_stride, args.block_table_stride,
args.context_lens, args.logits,
args.block_table, args.schedule_meta,
args.block_table, args.indices, args.schedule_meta,
args.tensor_map_q, args.tensor_map_sf_q,
args.tensor_map_kv, args.tensor_map_sf_kv,
args.tensor_map_weights
@@ -351,12 +365,14 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q,
const torch::Tensor& context_lens,
const torch::Tensor& logits,
const torch::Tensor& block_table,
const torch::Tensor& indices,
const torch::Tensor& schedule_meta,
const at::ScalarType& logits_dtype,
const int& batch_size, const int& next_n,
const int& num_heads, const int& head_dim,
const int& num_kv_blocks, const int& block_kv,
const bool& is_context_lens_2d,
const bool& is_varlen,
const int& logits_stride,
const int& block_table_stride,
const int& num_sms,
@@ -366,8 +382,8 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q,
DG_HOST_ASSERT(split_kv == 256 and logits_stride % split_kv == 0);
// TODO: tuning num_stages
const int num_q_stages = 3, num_kv_stages = 6, num_tmem_stages = 3;
const int next_n_atom = (next_n % 2 == 0) ? 2 : 1;
const int num_q_stages = 3, num_kv_stages = 10, num_tmem_stages = 3;
const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1;
// `head_dim` must be 128 for 64B swizzling
DG_HOST_ASSERT(head_dim == 128);
@@ -416,6 +432,7 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q,
.head_dim = head_dim,
.block_kv = block_kv,
.is_context_lens_2d = is_context_lens_2d,
.is_varlen = is_varlen,
.block_table_stride = block_table_stride,
.logits_stride = logits_stride,
.num_q_stages = num_q_stages,
@@ -424,6 +441,7 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q,
.context_lens = context_lens.data_ptr<int>(),
.logits = logits.data_ptr(),
.block_table = block_table.data_ptr<int>(),
.indices = is_varlen ? indices.data_ptr<int>() : nullptr,
.schedule_meta = schedule_meta.data_ptr<int>(),
.tensor_map_q = tensor_map_q,
.tensor_map_sf_q = tensor_map_sf_q,

View File

@@ -123,4 +123,4 @@ _C.init(
_find_cuda_home() # CUDA home
)
__version__ = '2.4.2'
__version__ = '2.5.0'

View File

@@ -1,11 +1,20 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/layout/sym_buffer.cuh>
#include <deep_gemm/layout/mega_moe.cuh>
namespace deep_gemm::comm {
CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() {
// Perform cluster_sync with `barrier.cluster.arrive.relaxed`
// This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee
cute::cluster_arrive_relaxed();
cute::cluster_wait();
}
template <uint32_t kNumSMs, uint32_t kGridSyncIndex = 0, typename sync_scope_t>
CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace,
const uint32_t& sm_idx, const uint32_t& thread_idx,

View File

@@ -401,7 +401,8 @@ void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
// Load accumulator from TMEM
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
tmem_load(cute::Int<kNumHeads>{}, tmem_addr, accum);
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
// Release TMEM empty
if (i == BLOCK_Q - 1) {

View File

@@ -20,7 +20,7 @@ namespace deep_gemm {
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D,
bool kIsContextLens2D, bool kIsVarlen,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
@@ -30,7 +30,8 @@ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* schedule_meta,
const uint32_t* block_table, const uint32_t* indices,
const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
@@ -54,10 +55,10 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
}
// Next-N atom configs
static constexpr uint32_t kNextNAtom = (kNextN % 2 == 0) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = kNextN / kNextNAtom;
static constexpr bool kSingleAtom = (kNumNextNAtoms == 1);
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
// UMMA configs
static constexpr uint32_t kNumTmemStages = 3;
@@ -157,7 +158,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
// Scheduler
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
// Make Q, KV and TMEM pipeline
@@ -182,7 +183,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
if (cute::elect_one_sync()) {
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
// Persistently schedule over blocks
// Initialize outside valid range to indicate no previous task
@@ -196,11 +197,12 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
// Issue TMA Q
const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx);
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_q[q_stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_atom_idx * kNextNAtom);
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_atom_idx * kNextNAtom);
smem_q[q_stage_idx], 0, q_token_idx * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx);
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx);
full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
}
last_q_atom_idx = q_atom_idx;
@@ -210,7 +212,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
} else if (warp_idx == kSpecWarpStart + 1) {
// TMA warp for loading KV cache
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
// Persistently schedule over blocks
uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage;
@@ -225,10 +227,11 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
// Coalesced load of block table
if (kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast<uint64_t>(block_table_stride);
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
}
__syncwarp();
// Broadcast KV block indices
int kv_block_idx[kNumBlocksPerSplit];
@@ -240,7 +243,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
// Wait KV consumer release
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
// Issue TMA KV
if (cute::elect_one_sync()) {
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
@@ -260,7 +263,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
} else if (warp_idx == kSpecWarpStart + 2) {
// UMMA warp
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// UTCCP transposer
@@ -371,7 +374,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
} else if (warp_idx < kSpecWarpStart) {
// Math warpgroups for reduce
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
const auto math_warpgroup_idx = warpgroup_idx;
const auto math_thread_idx = warp_idx * 32 + lane_idx;
@@ -400,6 +403,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
// Persistently schedule over blocks
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
uint32_t q_atom_idx, kv_idx, _;
bool is_paired_atom = false;
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
if (q_atom_idx != last_q_atom_idx) {
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
@@ -423,11 +427,16 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
weights[i][j + 3] = raw.w;
}
}
// Check if this atom pairs two tokens from the same sequence
if constexpr (kIsVarlen) {
is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2);
}
}
last_q_atom_idx = q_atom_idx;
// Calculate KV offset in advance
auto kv_offset = q_atom_idx * kNextNAtom * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
// Advance pipeline by `kNumMathWarpGroups` steps
// Wait UMMA arrival
@@ -436,53 +445,58 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
ptx::tcgen05_after_thread_sync();
// Reduce over the head dim and store
#pragma unroll
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
// Load accumulator from TMEM
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
tmem_load(cute::Int<kNumHeads>{}, tmem_addr, accum);
const auto reduce_and_store = [&](auto num_iters_c) {
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
// Only loop over valid iterations
#pragma unroll
for (uint32_t i = 0; i < kNumIters; ++ i) {
// Load accumulator from TMEM
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
// Store into the global memory
logits[kv_offset + i * static_cast<uint64_t>(logits_stride)] = result;
__syncwarp();
}
// Release TMEM empty
if (i == kNextNAtom - 1) {
ptx::tcgen05_before_thread_sync();
empty_tmem_barriers[tmem_stage_idx]->arrive();
}
ptx::tcgen05_before_thread_sync();
empty_tmem_barriers[tmem_stage_idx]->arrive();
};
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
// Store into the global memory
const auto dst_offset = kv_offset + i * static_cast<uint64_t>(logits_stride);
if constexpr(sizeof(logits_dtype_t) == 2) {
// Pack two adjacent bf16 lanes into uint32 for wider store
uint16_t my_bits = *reinterpret_cast<const uint16_t*>(&result);
uint16_t neighbor_bits = __shfl_down_sync(0xffffffff, my_bits, 1);
uint32_t packed;
asm volatile("mov.b32 %0, {%1, %2};" : "=r"(packed) : "h"(my_bits), "h"(neighbor_bits));
if (lane_idx % 2 == 0)
*reinterpret_cast<uint32_t*>(logits + dst_offset) = packed;
} else {
logits[dst_offset] = result;
}
// this sync warp prevent the next load tmem from reordering
// nvcc may reorder it to overlap with the current tmem load, lead to large register usage
__syncwarp();
if constexpr (kIsVarlen) {
if (is_paired_atom)
reduce_and_store(cute::Int<kNextNAtom>{});
else
reduce_and_store(cute::Int<1>{});
} else if constexpr (kPadOddN) {
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
reduce_and_store(cute::Int<1>{});
else
reduce_and_store(cute::Int<kNextNAtom>{});
} else {
reduce_and_store(cute::Int<kNextNAtom>{});
}
}

View File

@@ -48,6 +48,7 @@ template <
>
CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void
sm100_fp8_fp4_mega_moe_impl(void* y,
int* cumulative_local_expert_recv_stats,
const uint32_t num_tokens,
const __grid_constant__ layout::SymBuffer<kNumRanks> sym_buffer,
const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts,
@@ -91,7 +92,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Workspaces
const auto workspace = layout::Workspace(
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk, BLOCK_M);
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk);
// Token and buffer layouts
constexpr auto fp8_token_layout = layout::Data(kHidden);
@@ -170,7 +171,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
constexpr uint32_t UMMA_K = 32;
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N;
DG_STATIC_ASSERT(BLOCK_M % 32 == 0, "Invalid block M");
DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M");
DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N");
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
@@ -269,7 +270,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2);
// A cluster sync is essential for 2CTA tensor memory allocation
cute::cluster_sync();
comm::cluster_sync_with_relaxed_arrive();
// Initialization
if (warp_idx == 0) {
@@ -307,7 +308,9 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
cute::cluster_sync();
// NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`,
// and `barrier.cluster.wait.aligned` is by default `.acquire`
comm::cluster_sync_with_relaxed_arrive();
// Task scheduler
auto scheduler = sched::MegaMoEScheduler<
@@ -599,7 +602,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
__syncwarp();
}
// Clean workspace for the next usage
// Clean workspace for the next usage, and also do cumulative stats
// NOTES: it is overlapped with combine reduction epilogue
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
@@ -623,19 +626,27 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Wait read count ready
ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
// Clean expert token count
if (thread_idx == 0)
// Clean expert token count, and add cumulative results
DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps");
if (warp_idx == 0) {
*workspace.get_expert_recv_count_sum_ptr(i) = 0;
} else if (warp_idx == 1) {
if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr)
ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast<int>(num_recv_tokens));
__syncwarp();
}
// Clean per-rank token count
for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads)
*workspace.get_expert_recv_count_ptr(j, i) = 0;
__syncwarp();
// Clean L1 and L2 arrival stuffs
for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) {
*workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0;
*workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0;
}
__syncwarp();
}
}
@@ -672,23 +683,22 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx);
const auto expected = scheduler.template get_valid_m<false>();
while (ptx::ld_acq(ptr) != expected);
} else {
// The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival
// NOTES: Originally we wait blocks on-demand to overlap L1 calculation
// with L2, but this optimization is negative when `num_experts_per_wave`
// guarantees L1's completion when L2 starts. So we remove it.
// In the future, if `num_experts_per_wave` is not large enough
// due to small `num_experts_per_rank`, we may need to add it back or add a switch
DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes");
const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx);
// NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts
// to avoid undefined behavior when `num_k_blocks == 32`
const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1;
while (ptx::ld_acq_gpu(ptr) != expected);
}
uint64_t cached_l2_arrival_mask = 0;
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
// Wait current K block arrival
if (block_phase == sched::BlockPhase::Linear2) {
// The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2 L1 blocks' arrival
DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes");
const uint64_t needed = 3ull << (k_block_idx * 2);
if ((cached_l2_arrival_mask & needed) != needed) {
const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx);
do {
cached_l2_arrival_mask = ptx::ld_acq_gpu(ptr);
} while ((cached_l2_arrival_mask & needed) != needed);
}
}
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
@@ -953,8 +963,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Load weights from global into register cache per 32 tokens
DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size");
DG_STATIC_ASSERT(WG_BLOCK_M % 32 == 0, "Invalid block size");
if ((j * ATOM_M) % 32 == 0) {
if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) {
stored_cached_weight = *l1_topk_weights_buffer
.get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx)
.get_base_ptr<float>();
@@ -1060,19 +1069,26 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Only one warp per pair writes (both hold the same SF after cross-warp reduce)
// Each lane < 4 holds SF for 2 rows (sf.x and sf.y)
if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {
// TODO: I believe the expression can be optimized
const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M
+ epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2;
const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2;
const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4;
const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t);
const auto sf_base_ptr = l2_sf_buffer.get_base_ptr<uint8_t>();
// NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4
// NOTES: originally there was:
// - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2
// - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)`
// We find out that
// 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside
// 2. `lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside
// This reduce the number of computation instructions.
const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
__builtin_assume(token_base_idx < BLOCK_M);
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
+ transform_sf_token_idx(token_idx_in_expert);
sf_base_ptr[k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx] =
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
sf_base_ptr[sf_addr] =
(*reinterpret_cast<const uint32_t*>(&sf.x) >> 23);
sf_base_ptr[k_uint_idx * mn_stride + (sf_pool_token_idx + 4) * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx] =
sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] =
(*reinterpret_cast<const uint32_t*>(&sf.y) >> 23);
}
__syncwarp();

View File

@@ -20,7 +20,7 @@ namespace deep_gemm {
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D,
bool kIsContextLens2D, bool kIsVarlen,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
@@ -30,7 +30,8 @@ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* schedule_meta,
const uint32_t* block_table, const uint32_t* indices,
const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
@@ -53,10 +54,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
cute::prefetch_tma_descriptor(&tensor_map_weights);
}
// Next-N atom configs
static constexpr uint32_t kNextNAtom = (kNextN % 2 == 0) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = kNextN / kNextNAtom;
static constexpr bool kSingleAtom = (kNumNextNAtoms == 1);
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
@@ -136,7 +137,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Scheduler
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
// Q and KV pipeline
@@ -157,13 +158,14 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
if (warp_idx == kSpecWarpStart) {
// TMA warp for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_atom_idx) {
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) {
if (cute::elect_one_sync()) {
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_atom_idx * kNextNAtom);
const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx);
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
}
};
@@ -182,7 +184,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
while (fetched_next_task) {
// Prefetch next Q when (q, atom) changes
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + 1);
const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size);
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance);
if (q_atom_idx != next_q_atom_idx)
kv_block_idx_ptr = 32;
@@ -195,17 +198,18 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// TODO(xuzhean): consider -1
if (kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast<uint64_t>(block_table_stride);
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
}
__syncwarp();
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
// Wait Q consumer release and issue TMA Q
if (prefetch_q) {
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
issue_tma_q(q_stage_idx, q_atom_idx + 1);
issue_tma_q(q_stage_idx, q_atom_idx + next_advance);
}
uint32_t kv_block_idx[kNumBlocksPerSplit];
@@ -236,7 +240,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
}
} else if (warp_idx == kSpecWarpStart + 1) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// Require full allocation
@@ -292,7 +296,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
} else if (warp_idx < kSpecWarpStart) {
// Math warpgroups for reduce
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// Offsets
@@ -321,6 +325,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
uint32_t q_stage_idx, q_phase;
uint32_t umma_phase = 0;
bool is_paired_atom = false;
while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
// Q or atom changes
@@ -340,6 +345,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
for (uint32_t j = 0; j < kNumHeads; ++ j)
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
if constexpr (kIsVarlen) {
is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2);
}
}
// Get current task indices
@@ -347,7 +356,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
kv_idx = next_kv_idx;
// Calculate KV offset in advance
auto kv_offset = q_atom_idx * kNextNAtom * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
@@ -367,40 +376,56 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Reduce over the head dim and store
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
#pragma unroll
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
// Load accumulator from TMEM
const auto reduce_and_store = [&](auto num_iters_c) {
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
float accum[kNumHeads];
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
// Release TMEM empty
if (i == kNextNAtom - 1) {
ptx::tcgen05_before_thread_sync();
empty_umma_barriers[math_warpgroup_idx]->arrive();
}
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
for (uint32_t i = 0; i < kNumIters; ++ i) {
// Load accumulator from TMEM
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
// Store into the global memory
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
__syncwarp();
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
// Release TMEM empty
ptx::tcgen05_before_thread_sync();
empty_umma_barriers[math_warpgroup_idx]->arrive();
};
// Store into the global memory
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
__syncwarp();
if constexpr (kIsVarlen) {
if (is_paired_atom)
reduce_and_store(cute::Int<kNextNAtom>{});
else
reduce_and_store(cute::Int<1>{});
} else if constexpr (kPadOddN) {
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
reduce_and_store(cute::Int<1>{});
else
reduce_and_store(cute::Int<kNextNAtom>{});
} else {
reduce_and_store(cute::Int<kNextNAtom>{});
}
}

View File

@@ -21,7 +21,7 @@ namespace deep_gemm {
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D,
bool kIsContextLens2D, bool kIsVarlen,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
@@ -30,11 +30,14 @@ CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* schedule_meta,
const uint32_t* block_table, const uint32_t* indices,
const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits");
// Types
using WGMMA = typename mma::sm90::FP8MMASelector<kNextN * kNumHeads>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
@@ -132,8 +135,8 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
cudaGridDependencySynchronize();
// Scheduler
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups, 1>(
blockIdx.x, context_lens, schedule_meta);
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumMathWarpGroups, 1>(
blockIdx.x, batch_size, context_lens, schedule_meta, indices);
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
// Q and KV pipeline

View File

@@ -1,19 +1,27 @@
#pragma once
#include <cute/numeric/math.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::layout {
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding
static constexpr int kNumCandidateBlockMs = 7;
static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192};
static constexpr int kMaxCandidateBlockM = 192;
static constexpr int kMinCandidateBlockM = 8;
static constexpr int kLCMCandidateBlockM = 384;
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M
template <typename T>
CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
T num_experts_per_rank, T block_m) {
T num_experts_per_rank) {
const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank;
const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
return math::constexpr_align(
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (block_m - 1),
block_m);
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1),
static_cast<T>(kLCMCandidateBlockM));
}
// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M
@@ -48,17 +56,14 @@ struct Workspace {
const uint32_t& num_ranks,
const uint32_t& num_experts,
const uint32_t& num_max_tokens_per_rank,
const uint32_t& num_topk,
const uint32_t& block_m):
const uint32_t& num_topk):
base(base),
num_ranks(num_ranks), num_experts(num_experts),
num_max_tokens_per_rank(num_max_tokens_per_rank) {
num_experts_per_rank = num_experts / num_ranks;
num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
num_max_pool_tokens = get_num_max_pool_tokens(
num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank, block_m);
num_max_pool_blocks = num_max_pool_tokens / block_m;
DG_UNIFIED_ASSERT(num_max_tokens_per_rank % block_m == 0);
num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM;
}
CUTLASS_HOST_DEVICE

View File

@@ -164,7 +164,7 @@ CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) {
}
CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) {
asm volatile("st.L1::no_allocate.relaxed.sys.u64 [%0], %1;" :: "l"(ptr), "l"(value));
asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value));
}
/// Atomics
@@ -186,7 +186,11 @@ CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& valu
return ret;
}
__forceinline__ __device__ void red_add(const uint32_t* ptr, const uint32_t& value) {
CUTLASS_DEVICE void red_add(const int* ptr, const int& value) {
asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value));
}
CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) {
asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value));
}

View File

@@ -6,22 +6,51 @@
namespace deep_gemm::sched {
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs, bool kIsVarlen = false>
CUTLASS_GLOBAL __launch_bounds__(32, 1)
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
const uint32_t* context_lens, uint32_t* schedule_metadata) {
const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) {
DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
const uint32_t lane_idx = ptx::get_lane_idx();
// Wait for primary kernel completion
cudaGridDependencySynchronize();
__shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize];
__shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize];
__shared__ uint32_t varlen_num_atoms_shared;
uint32_t num_items;
if constexpr (kIsVarlen) {
if (lane_idx == 0) {
uint32_t t = 0, atom_count = 0;
while (t < batch_size) {
varlen_atom_token_start[atom_count] = t;
const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]);
varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t];
t += is_paired ? 2 : 1;
++ atom_count;
}
varlen_num_atoms_shared = atom_count;
}
__syncwarp();
num_items = varlen_num_atoms_shared;
} else {
num_items = batch_size;
}
// Compute num_segs and prefix sum
uint32_t num_segs[kAlignedBatchSize / 32];
#pragma unroll
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
const uint32_t q_idx = k * 32 + lane_idx;
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
const uint32_t context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
uint32_t context_len;
if constexpr (kIsVarlen) {
context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0);
} else {
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
}
num_segs[k] = math::ceil_div(context_len, SPLIT_KV);
}
@@ -40,44 +69,118 @@ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t ne
sum = __shfl_sync(0xffffffff, x, 31);
}
const uint32_t num_next_n_atoms = next_n / ((next_n % 2 == 0) ? 2 : 1);
const uint32_t total = sum * num_next_n_atoms;
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
uint32_t q_idx = 0;
while (q_idx < batch_size and prefix_sum[q_idx] * num_next_n_atoms <= seg_starts)
++ q_idx;
const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]);
const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
__syncwarp();
// SM work distribution
if constexpr (kIsVarlen) {
const uint32_t total = sum;
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
uint32_t lo = 0, hi = num_items;
while (lo < hi) {
const uint32_t mid = (lo + hi) / 2;
const bool pred = prefix_sum[mid] <= seg_starts;
lo = pred ? mid + 1 : lo;
hi = pred ? hi : mid;
}
const uint32_t atom_idx = lo;
const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]);
const uint32_t q_atom_idx = (atom_idx < num_items ? varlen_atom_token_start[atom_idx] : batch_size);
__syncwarp();
schedule_metadata[sm_idx * 2] = q_atom_idx;
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
schedule_metadata[sm_idx * 2] = q_atom_idx;
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
}
} else {
const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1;
const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom);
const uint32_t total = sum * num_next_n_atoms;
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
uint32_t lo = 0, hi = batch_size;
while (lo < hi) {
const uint32_t mid = (lo + hi) / 2;
const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts;
lo = pred ? mid + 1 : lo;
hi = pred ? hi : mid;
}
const uint32_t q_idx = lo;
const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]);
const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
__syncwarp();
schedule_metadata[sm_idx * 2] = q_atom_idx;
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
}
}
}
template <uint32_t kNextN, bool kIsContextLens2D,
// Conditional storage for varlen indices pointer (EBO: zero cost when unused)
template <bool kHasIndices>
struct IndicesStorage {
const uint32_t* indices;
};
template <>
struct IndicesStorage<false> {};
template <uint32_t kNextN, bool kIsContextLens2D, bool kIsVarlen,
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit,
uint32_t kNumNextNAtoms>
struct PagedMQALogitsScheduler {
struct PagedMQALogitsScheduler : IndicesStorage<kIsVarlen> {
const uint32_t* context_lens;
uint32_t batch_size;
uint32_t current_q_atom_idx, current_kv_idx;
uint32_t end_q_atom_idx, end_kv_idx;
uint32_t current_num_kv;
CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const {
const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) {
if constexpr (kIsVarlen) {
return q_atom_idx;
} else {
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
if constexpr (kPadOddN) {
return q_atom_idx / kNumNextNAtoms * kNextN + q_atom_idx % kNumNextNAtoms * kNextNAtom;
} else {
return q_atom_idx * kNextNAtom;
}
}
}
CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t* context_lens, const uint32_t* schedule_meta) {
CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) {
if constexpr (kIsVarlen) {
return q_atom_idx;
} else {
return q_atom_idx / kNumNextNAtoms;
}
}
CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const {
if constexpr (kIsVarlen) {
const bool is_paired = (q_atom_idx + 1 < batch_size and
this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]);
const uint32_t ctx_len = is_paired ? context_lens[q_atom_idx + 1] : context_lens[q_atom_idx];
return math::ceil_div(ctx_len, BLOCK_KV);
} else {
const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
}
}
CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size,
const uint32_t* context_lens,
const uint32_t* schedule_meta, const uint32_t* indices) {
this->context_lens = context_lens;
this->batch_size = batch_size;
if constexpr (kIsVarlen) {
this->indices = indices;
}
const auto current_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx];
const auto end_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx + 1];
@@ -87,6 +190,28 @@ struct PagedMQALogitsScheduler {
current_num_kv = get_num_kv(current_q_atom_idx);
}
// Advance step in q_atom_idx space when moving to the next atom.
// Varlen: 1 or 2 depending on whether consecutive tokens share the same sequence.
// Non-varlen: always 1 (one atom unit).
CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const {
if constexpr (kIsVarlen) {
return (q_atom_idx + 1 < bound and this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]) ? 2 : 1;
} else {
return 1;
}
}
// Whether num_kv should be refreshed after advancing to q_atom_idx.
// Varlen: always refresh (each atom may have a different context_len).
// Non-varlen: only at atom-group boundaries (atoms within a group share context_len).
CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const {
if constexpr (kIsVarlen) {
return true;
} else {
return q_atom_idx % kNumNextNAtoms == 0;
}
}
CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) {
q_atom_idx = current_q_atom_idx;
kv_idx = current_kv_idx;
@@ -97,9 +222,9 @@ struct PagedMQALogitsScheduler {
current_kv_idx += kNumBlocksPerSplit;
if (current_kv_idx >= current_num_kv) {
++ current_q_atom_idx;
current_kv_idx = 0;
if (current_q_atom_idx % kNumNextNAtoms == 0 and exist_q_atom_idx(current_q_atom_idx)) {
current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx);
if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) {
current_num_kv = get_num_kv(current_q_atom_idx);
}
}

View File

@@ -61,10 +61,8 @@ def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
hidden: int, intermediate_hidden: int,
use_fp8_dispatch: bool = True,
activation: str = 'swiglu') -> SymmBuffer:
# Token count must be aligned to block m
num_ranks = group.size()
block_m = _C.get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk)
num_max_tokens_per_rank = align(num_max_tokens_per_rank, block_m)
# Token count must be aligned to block sizes
num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe())
return SymmBuffer(
group, num_experts,
@@ -111,6 +109,7 @@ def fp8_fp4_mega_moe(y: torch.Tensor,
l1_weights: Tuple[torch.Tensor, torch.Tensor],
l2_weights: Tuple[torch.Tensor, torch.Tensor],
sym_buffer: SymmBuffer,
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
recipe: Tuple[int, int, int] = (1, 1, 32),
activation: str = 'swiglu',
activation_clamp: Optional[float] = None,
@@ -118,6 +117,7 @@ def fp8_fp4_mega_moe(y: torch.Tensor,
_C.fp8_fp4_mega_moe(
y,
l1_weights, l2_weights,
cumulative_local_expert_recv_stats,
sym_buffer.buffer,
sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
sym_buffer.num_max_tokens_per_rank,

448
scripts/quick_plot_pm.py Normal file
View File

@@ -0,0 +1,448 @@
#!/usr/bin/env python3
"""Plot a curated set of NCU PM metrics from an .ncu-rep report.
Usage:
python scripts/quick_plot_pm.py [report.ncu-rep]
By default the script saves a PNG next to the report.
With --interactive, it opens a Qt window instead.
"""
import argparse
import csv
import io
import subprocess
from dataclasses import dataclass
import matplotlib
import numpy as np
@dataclass(frozen=True)
class MetricSpec:
name: str
metric: str
kind: str
category: str
aliases: tuple[str, ...] = ()
@dataclass(frozen=True)
class ResolvedMetricSpec:
name: str
metric: str
kind: str
category: str
@dataclass(frozen=True)
class MetricSeries:
name: str
metric: str
category: str
unit: str
values: tuple[float, ...]
CATEGORY_ORDER = [
"Overview",
"SM",
"L1",
"L2",
"DRAM",
"Interconnect",
]
KIND_SUFFIXES = {
"pct_peak": [".avg.pct_of_peak_sustained_elapsed"],
"pct": [".pct", ".avg.pct_of_peak_sustained_elapsed"],
"avg": [".avg"],
"sum": [".sum"],
"avg_per_second": [".avg.per_second"],
"sum_per_second": [".sum.per_second"],
"avg_per_cycle_active": [".avg.per_cycle_active"],
"avg_per_cycle_elapsed": [".avg.per_cycle_elapsed"],
"sum_per_cycle_elapsed": [".sum.per_cycle_elapsed"],
}
# Curated from scripts/ncu-metrics.txt, with a few corrections against
# `ncu --query-metrics --chip gb100`:
# - Blocks launched uses `gr__ctas_launched_realtime`
# - SM active cycles uses `sm__cycles_active`
# - L2 throughput for GCC requests uses `lts__t_sector_throughput_srcunit_gcc`
# - C2C throughput uses `ctc__throughput`
# - NVLink RX metrics use the `NVLRX` domain
CURATED_METRICS = [
MetricSpec("Blocks Launched", "FE_B.TriageCompute.gr__ctas_launched_realtime", "sum_per_cycle_elapsed", "Overview"),
MetricSpec("Average Blocks Active", "TPC.TriageCompute.tpc__ctas_active_realtime", "avg_per_cycle_elapsed", "Overview"),
MetricSpec("Total Blocks Active", "TPC.TriageCompute.tpc__ctas_active_realtime", "sum_per_cycle_elapsed", "Overview"),
MetricSpec("Average CGAs Active", "GPC_B.TriageCompute.gpc__cgas_active_realtime", "avg_per_cycle_elapsed", "Overview"),
MetricSpec("Total CGAs Active", "GPC_B.TriageCompute.gpc__cgas_active_realtime", "sum_per_cycle_elapsed", "Overview"),
MetricSpec("SM Active Cycles", "TPC.TriageCompute.sm__cycles_active", "avg", "SM"),
MetricSpec("Executed IPC Active", "TPC.TriageCompute.sm__inst_executed_realtime", "avg_per_cycle_active", "SM"),
MetricSpec("Executed IPC Elapsed", "TPC.TriageCompute.sm__inst_executed_realtime", "avg_per_cycle_elapsed", "SM"),
MetricSpec("SM Throughput", "TPC.TriageCompute.sm__inst_executed_realtime", "pct_peak", "SM"),
MetricSpec("SM ALU Pipe Throughput", "TPC.TriageCompute.sm__inst_executed_pipe_alu_realtime", "pct_peak", "SM"),
MetricSpec("SM FMA Pipe Throughput", "TPC.TriageCompute.sm__pipe_fma_cycles_active_realtime", "pct_peak", "SM"),
MetricSpec("SM FMA Heavy Pipe Throughput", "TPC.TriageCompute.sm__pipe_fmaheavy_cycles_active_realtime", "pct_peak", "SM"),
MetricSpec("SM FMA Light Pipe Throughput", "TPC.TriageCompute.sm__pipe_fmalite_cycles_active_realtime", "pct_peak", "SM"),
MetricSpec("SM Tensor Pipe Throughput", "TPC.TriageCompute.sm__pipe_tensor_cycles_active_realtime", "pct_peak", "SM"),
MetricSpec("SM TMEM Pipe Throughput", "SM_A.TriageCompute.sm__mem_tensor_cycles_active_realtime", "pct_peak", "SM"),
MetricSpec("SM Uniform Pipe Throughput", "SM_A.TriageCompute.sm__inst_executed_pipe_uniform_realtime", "pct_peak", "SM"),
MetricSpec("SM XU Pipe Throughput", "SM_A.TriageCompute.sm__inst_executed_pipe_xu_realtime", "pct_peak", "SM"),
MetricSpec("L1 Throughput", "SM_A.TriageCompute.l1tex__throughput", "pct_peak", "L1"),
MetricSpec("L1 Sectors", "SM_B.TriageCompute.l1tex__t_sectors", "sum", "L1"),
MetricSpec("L1 Hit Rate", "SM_B.TriageCompute.l1tex__t_sector_hit_rate", "pct", "L1"),
MetricSpec("L1 Lookup Hit", "SM_B.TriageCompute.l1tex__t_sectors_lookup_hit", "sum", "L1"),
MetricSpec("L1 Lookup Miss", "SM_B.TriageCompute.l1tex__t_sectors_lookup_miss", "sum", "L1"),
MetricSpec("L1 Wavefronts (Data)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts", "avg", "L1"),
MetricSpec("L1 Wavefronts (LGDS)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts_mem_lgds", "avg", "L1"),
MetricSpec("L1 Wavefronts (Shared)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts_mem_shared", "avg", "L1"),
MetricSpec("L2 Throughput", "LTS.TriageCompute.lts__throughput", "pct_peak", "L2"),
MetricSpec("L2 Throughput for L1 Requests", "LTS.TriageCompute.lts__t_sector_throughput_srcunit_tex", "pct_peak", "L2"),
MetricSpec("L2 Throughput for GCC Requests", "LTS.TriageCompute.lts__t_sector_throughput_srcunit_gcc", "pct_peak", "L2"),
MetricSpec("L2 Throughput to DRAM", "LTS.TriageCompute.lts__t_sector_throughput_srcnode_fbp", "pct_peak", "L2"),
MetricSpec("SysL2 Throughput to Peer Memory", "SYSLTS.TriageCompute.syslts__t_sector_throughput_aperture_peer", "pct_peak", "L2"),
MetricSpec("SysL2 Throughput to System Memory", "SYSLTS.TriageCompute.syslts__t_sector_throughput_aperture_sysmem", "pct_peak", "L2"),
MetricSpec("L2 Hit Rate", "LTS.TriageCompute.lts__average_t_sector_hit_rate_realtime", "pct", "L2"),
MetricSpec("L2 Hit Rate From L1", "LTS.TriageCompute.lts__average_t_sector_hit_rate_srcunit_tex_realtime", "pct", "L2"),
MetricSpec("DRAM Frequency", "FBSP.TriageCompute.dram__cycles_elapsed", "avg_per_second", "DRAM"),
MetricSpec("DRAM Throughput", "FBSP.TriageCompute.dram__throughput", "pct_peak", "DRAM"),
MetricSpec("DRAM Read Throughput", "FBSP.TriageCompute.dram__read_throughput", "pct_peak", "DRAM"),
MetricSpec("DRAM Write Throughput", "FBSP.TriageCompute.dram__write_throughput", "pct_peak", "DRAM"),
MetricSpec("C2C Throughput", "TriageCompute.ctc__throughput", "pct_peak", "Interconnect", aliases=("TriageCompute.ctx__throughput",)),
MetricSpec("NVLink Transmitted Throughput", "NVLTX.TriageCompute.nvltx__bytes", "pct_peak", "Interconnect"),
MetricSpec("NVLink Received Throughput", "NVLRX.TriageCompute.nvlrx__bytes", "pct_peak", "Interconnect"),
MetricSpec("NVLink Transmitted Bandwidth", "NVLTX.TriageCompute.nvltx__bytes", "sum_per_second", "Interconnect"),
MetricSpec("NVLink Received Bandwidth", "NVLRX.TriageCompute.nvlrx__bytes", "sum_per_second", "Interconnect"),
MetricSpec("PCIe Throughput", "PCI.TriageCompute.pcie__throughput", "pct_peak", "Interconnect"),
MetricSpec("PCIe Read Bandwidth", "PCI.TriageCompute.pcie__read_bytes", "sum_per_second", "Interconnect"),
MetricSpec("PCIe Write Bandwidth", "PCI.TriageCompute.pcie__write_bytes", "sum_per_second", "Interconnect"),
]
def _run_csv_command(command, timeout):
result = subprocess.run(command, capture_output=True, text=True, timeout=timeout)
if result.returncode != 0 and not result.stdout:
return None
reader = csv.reader(io.StringIO(result.stdout))
return list(reader)
def _query_available_metrics(chip):
result = subprocess.run(
["ncu", "--query-metrics", "--chip", chip],
capture_output=True,
text=True,
timeout=30,
)
if result.returncode != 0:
raise RuntimeError(result.stderr.strip() or f"failed to query metrics for chip {chip}")
metrics = set()
for line in result.stdout.splitlines():
parts = line.split()
if not parts:
continue
token = parts[0]
if "__" not in token:
continue
metrics.add(token)
return metrics
def _metric_candidates(metric):
candidates = [metric]
marker = ".TriageCompute."
if marker in metric:
candidates.append(metric.split(marker, 1)[1])
return candidates
def resolve_metric_specs(chip):
available = _query_available_metrics(chip)
resolved = []
missing = []
for spec in CURATED_METRICS:
candidates = []
for metric in (spec.metric, *spec.aliases):
candidates.extend(_metric_candidates(metric))
actual_metric = next((metric for metric in candidates if metric in available), None)
if actual_metric is None:
missing.append(spec)
continue
resolved.append(ResolvedMetricSpec(spec.name, actual_metric, spec.kind, spec.category))
return resolved, missing
def _parse_metric_values(raw):
if not raw or raw == "no data":
return ()
try:
if raw.startswith("(") and raw.endswith(")"):
rest = raw[1:-1]
return tuple(float(v.strip().replace(",", "")) for v in rest.split(";") if v.strip())
if " (" in raw:
_agg, rest = raw.split(" (", 1)
rest = rest.rstrip(")")
return tuple(float(v.strip().replace(",", "")) for v in rest.split(";") if v.strip())
return (float(raw.replace(",", "")),)
except ValueError:
return ()
def _probe_metric_series(report, metric_name):
rows = _run_csv_command(
[
"ncu",
"--import",
report,
"--page",
"raw",
"--csv",
"--metrics",
metric_name,
"--print-metric-instances",
"values",
],
timeout=60,
)
if not rows or len(rows) < 3 or len(rows[0]) <= 11:
return None
header, units, row = rows[0], rows[1], rows[2]
unit = units[11] if len(units) > 11 else ""
raw = row[11] if len(row) > 11 else ""
values = _parse_metric_values(raw)
return header[11], unit, values
def collect_metric_series(report, resolved_specs):
collected = []
skipped = []
for spec in resolved_specs:
series = None
for suffix in KIND_SUFFIXES[spec.kind]:
probe = _probe_metric_series(report, f"{spec.metric}{suffix}")
if probe is None:
continue
full_metric, unit, values = probe
if len(values) > 1:
series = MetricSeries(spec.name, full_metric, spec.category, unit, values)
break
if series is None:
skipped.append(spec)
continue
collected.append(series)
return collected, skipped
def _format_value(value):
if value == 0:
return "0"
abs_value = abs(value)
if abs_value >= 1e12:
return f"{value / 1e12:.2f} T"
if abs_value >= 1e9:
return f"{value / 1e9:.2f} G"
if abs_value >= 1e6:
return f"{value / 1e6:.2f} M"
if abs_value >= 1e3:
return f"{value / 1e3:.2f} K"
if abs_value >= 1:
return f"{value:.1f}"
return f"{value:.2f}"
def _format_with_unit(value, unit):
if not unit:
return _format_value(value)
return f"{_format_value(value)} {unit}"
def plot_pm(report, metrics, save=False):
"""Plot curated PM metrics as shared-x subplots in a light theme."""
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
if not metrics:
print("No curated metrics had time-series data in the report.")
return
bg_fig = "#ffffff"
bg_row = "#f6f8fb"
text_primary = "#1f2937"
text_secondary = "#6b7280"
text_header = "#111827"
grid_color = "#d7deea"
border = "#c7d0dd"
wave_colors = {
"Overview": "#7c8aa5",
"SM": "#4f87c2",
"L1": "#2f9d8f",
"L2": "#dd8452",
"DRAM": "#c95d63",
"Interconnect": "#8c6bb1",
}
category_rank = {category: index for index, category in enumerate(CATEGORY_ORDER)}
metrics = sorted(metrics, key=lambda item: (category_rank.get(item.category, 99), item.name))
row_h = 0.55
label_w = 3.6
plot_w = 14.0
fig_w = label_w + plot_w
fig_h = row_h * len(metrics) + 0.6
fig = plt.figure(figsize=(fig_w, fig_h), facecolor=bg_fig)
gs = GridSpec(
len(metrics),
1,
figure=fig,
left=label_w / fig_w,
right=0.97,
top=1 - 0.45 / fig_h,
bottom=0.35 / fig_h,
hspace=0.18,
)
axes = [fig.add_subplot(gs[i, 0]) for i in range(len(metrics))]
prev_category = None
for idx, metric in enumerate(metrics):
ax = axes[idx]
values = np.array(metric.values)
x = np.arange(len(values))
wave_color = wave_colors.get(metric.category, "#5b9bd5")
ax.set_facecolor(bg_row)
ax.fill_between(x, values, alpha=0.35, color=wave_color, linewidth=0)
ax.plot(x, values, linewidth=0.8, color=wave_color)
ax.set_xlim(0, len(values) - 1)
if metric.unit == "%":
ax.set_ylim(0, 100)
else:
ymax = np.max(values)
ax.set_ylim(0, ymax * 1.15 if ymax > 0 else 1)
ax.grid(True, axis="both", color=grid_color, linewidth=0.5, alpha=0.85)
ax.tick_params(axis="both", colors=text_secondary, labelsize=6, length=0)
if idx < len(metrics) - 1:
ax.tick_params(axis="x", labelbottom=False)
else:
ax.set_xlabel("Sample Index", fontsize=8, color=text_secondary)
ymin_v, ymax_v = ax.get_ylim()
ax.set_yticks([ymin_v, ymax_v])
ax.set_yticklabels([_format_value(ymin_v), _format_value(ymax_v)], fontsize=6, color=text_secondary)
peak = np.max(values)
ax.text(
1.005,
0.5,
_format_with_unit(peak, metric.unit),
transform=ax.transAxes,
fontsize=7,
color=text_secondary,
va="center",
ha="left",
family="monospace",
)
for spine in ax.spines.values():
spine.set_color(border)
spine.set_linewidth(0.5)
if metric.category != prev_category:
cat_y = ax.get_position().y1 + 0.008
fig.text(
0.005,
cat_y,
f" {metric.category}",
fontsize=8.5,
fontweight="bold",
color=text_header,
va="bottom",
family="sans-serif",
transform=fig.transFigure,
bbox=dict(boxstyle="square,pad=0.15", facecolor="#e9eef5", edgecolor="none"),
)
prev_category = metric.category
label_y = (ax.get_position().y0 + ax.get_position().y1) / 2
fig.text(
label_w / fig_w - 0.012,
label_y,
metric.name,
fontsize=7.5,
color=text_primary,
va="center",
ha="right",
family="sans-serif",
transform=fig.transFigure,
)
fig.text(
0.5,
1 - 0.15 / fig_h,
f"PM Sampling - {report}",
fontsize=11,
fontweight="bold",
color=text_header,
ha="center",
va="top",
family="sans-serif",
transform=fig.transFigure,
)
if save:
out_path = report.replace(".ncu-rep", ".pm_sampling.png")
fig.savefig(out_path, dpi=150, facecolor=bg_fig, bbox_inches="tight", pad_inches=0.2)
print(f"Saved: {out_path}")
plt.close(fig)
else:
plt.show()
def main():
parser = argparse.ArgumentParser(description="NCU PM Sampling plotter")
parser.add_argument("report", nargs="?", default="mega-moe-kk.3.ncu-rep", help="Path to .ncu-rep file")
parser.add_argument("--chip", default="gb100", help="Chip name used for `ncu --query-metrics`")
parser.add_argument("--interactive", action="store_true", help="Open an interactive Qt window instead of saving a PNG")
args = parser.parse_args()
if args.interactive:
matplotlib.use("QtAgg")
else:
matplotlib.use("Agg")
resolved_specs, missing_specs = resolve_metric_specs(args.chip)
if missing_specs:
print(f"Skipped {len(missing_specs)} curated metrics not available on {args.chip}.")
for spec in missing_specs:
print(f" missing: {spec.name} -> {spec.metric}")
metric_series, skipped_specs = collect_metric_series(args.report, resolved_specs)
if skipped_specs:
print(f"Skipped {len(skipped_specs)} curated metrics with no time-series data in {args.report}.")
for spec in skipped_specs:
print(f" no series: {spec.name} -> {spec.metric}")
plot_pm(args.report, metric_series, save=not args.interactive)
if __name__ == "__main__":
main()

89
scripts/run_ncu_mega_moe.sh Executable file
View File

@@ -0,0 +1,89 @@
#!/bin/bash
set -e
# parse num-processes, output_dir and separate python args
num_processes=8
output_dir=work
python_args=()
for ((arg_idx = 1; arg_idx <= $#; ++arg_idx)); do
arg="${!arg_idx}"
case "$arg" in
--num-processes)
python_args+=("$arg")
if ((arg_idx < $#)); then
((arg_idx++))
num_processes="${!arg_idx}"
python_args+=("$num_processes")
fi
;;
-h|--help)
echo "Usage: $0 [--num-processes N] [--output DIR] [python args...]"
exit 0
;;
--num-processes=*)
num_processes="${arg#*=}"
python_args+=("$arg")
;;
-o|--output)
if ((arg_idx < $#)); then
((arg_idx++))
output_dir="${!arg_idx}"
fi
;;
--output=*)
output_dir="${arg#*=}"
;;
*)
python_args+=("$arg")
;;
esac
done
echo "Python Args: ${python_args[*]}"
echo "Num Processes: $num_processes"
echo "Output Dir: $output_dir"
mkdir -p $output_dir
export DG_JIT_WITH_LINEINFO=1 # for source counters
echo "Warm up JIT cache"
python tests/test_mega_moe.py --ncu-profile-only "${python_args[@]}"
sleep 2
ncu_args=(
--config-file off
--force-overwrite
--kernel-name sm100_fp8_fp4_mega_moe_impl
--import-source yes
--replay-mode application
--section PmSampling
--section SourceCounters
--rule LocalMemoryUsage
--launch-skip 0
--launch-count 1
--lockstep-kernel-launch
--communicator tcp
--clock-control none
--pm-sampling-interval 1000
--pm-sampling-max-passes 1
--disable-pm-warp-sampling
--communicator-tcp-num-peers "$num_processes"
--kill yes
--app-replay-buffer memory
)
echo "Run Job"
for ((i = 0; i < num_processes; ++i)); do
ncu ${ncu_args[@]} -o "${output_dir%/}/mega-moe.$i" \
python tests/test_mega_moe.py \
--local-rank-idx=$i \
--ncu-profile-only \
"${python_args[@]}" &
done
echo "Waiting"
wait
echo "Done"

View File

@@ -253,36 +253,61 @@ def test_paged_mqa_logits():
def enumerate_paged_mqa_logits():
arch_major = get_arch_major()
for is_fp4 in ((True, False) if arch_major == 10 else (False, )):
for logits_dtype in (torch.float, torch.bfloat16):
for block_kv in ((32, 64) if arch_major == 10 else (64, )):
for use_2d_context_lens, clean_logits in [(True, False)]:
for batch_size in (256, ):
for next_n in (1, 2, 4, 5, 6) if arch_major == 10 else (1, 2):
for num_heads, head_dim in [(64, 128)]:
for avg_kv in (8192, 32768):
yield is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, num_heads, head_dim, avg_kv
for is_varlen in ((True, False) if arch_major == 10 else (False, )):
for is_fp4 in ((True, False) if arch_major == 10 else (False, )):
for logits_dtype in (torch.float, torch.bfloat16):
for block_kv in ((32, 64) if arch_major == 10 else (64, )):
for use_2d_context_lens, clean_logits in [(True, False)]:
for batch_size in (256, ):
for next_n in ((1, ) if is_varlen else ((1, 2, 4, 5, 6) if arch_major == 10 else (1, 2))):
for max_tokens_per_batch in ((1, 4, 10) if is_varlen else (1, )):
for num_heads, head_dim in [(64, 128)]:
for avg_kv in (8192, 32768):
yield is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv
print('Testing FP8/FP4 Paged MQA Logits:')
max_model_len = 111 * 1024
num_total_blocks = max_model_len * 5
for is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits():
for is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits():
# Varlen: flatten raw_batch_size sequences with variable tokens into (batch_size, 1, ...)
raw_batch_size, raw_next_n = batch_size, next_n
if is_varlen:
tokens_per_seq = torch.randint(1, max_tokens_per_batch + 1, (raw_batch_size,), device='cuda', dtype=torch.int)
indices = torch.arange(raw_batch_size, device='cuda', dtype=torch.int).repeat_interleave(tokens_per_seq)
batch_size, next_n = tokens_per_seq.sum().item(), 1
else:
tokens_per_seq, indices = None, None
# Generate random inputs
q = torch.randn((batch_size, next_n, num_heads, head_dim), device='cuda', dtype=torch.bfloat16)
kv_cache = torch.randn((num_total_blocks, block_kv, 1, head_dim), device='cuda', dtype=torch.bfloat16)
weights = torch.randn((batch_size * next_n, num_heads), device='cuda', dtype=torch.float)
context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size,), device='cuda', dtype=torch.int)
context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (raw_batch_size,), device='cuda', dtype=torch.int)
# Assign block tables
num_blocks_per_query = ceil_div(context_lens, block_kv)
block_table = torch.empty((batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int)
if is_varlen:
max_ctx_len_per_seq = context_lens + (tokens_per_seq - 1)
else:
max_ctx_len_per_seq = context_lens
# Assign block tables (per-sequence, sized by the largest ctx_len within the sequence)
seq_sum_lens = context_lens.sum().item()
num_blocks_per_query = ceil_div(max_ctx_len_per_seq, block_kv)
block_table = torch.empty((raw_batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int)
block_idx_pool = torch.randperm(num_total_blocks, device='cuda', dtype=torch.int)
offset = 0
for i, num_blocks in enumerate(num_blocks_per_query.tolist()):
block_table[i, :num_blocks] = block_idx_pool[offset : offset + num_blocks]
offset += num_blocks
if is_varlen:
context_lens = context_lens.repeat_interleave(tokens_per_seq)
offsets_within_seq = torch.cat([
torch.arange(n.item(), device='cuda', dtype=torch.int)
for n in tokens_per_seq
])
context_lens = context_lens + offsets_within_seq
block_table = block_table.repeat_interleave(tokens_per_seq, dim=0)
# Calculate reference logits
ref_logits = ref_paged_mqa_logits(q, kv_cache, weights, context_lens, block_table, max_model_len, use_2d_context_lens)
@@ -304,9 +329,14 @@ def test_paged_mqa_logits():
# Prepare masks and context lengths with NextN
positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1)
if use_2d_context_lens:
context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int()
# Ensure last token matches actual length
context_lens_nextn[:, -1] = context_lens
if is_varlen:
# Varlen: context_lens is already per-token (shape [total_tokens]);
# just reshape to (total_tokens, 1) so each token keeps its own ctx_len.
context_lens_nextn = context_lens.view(-1, 1)
else:
context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int()
# Ensure last token matches actual length
context_lens_nextn[:, -1] = context_lens
ref_neginf_mask = ~(positions < context_lens_nextn.view(-1, 1))
else:
context_lens_nextn = context_lens
@@ -318,8 +348,9 @@ def test_paged_mqa_logits():
kernel_kwargs = dict(
q=q_in, kv_cache=kv_in, weights=weights,
context_lens=context_lens_nextn, block_table=block_table,
schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms()),
max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype
schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms(), indices=indices),
max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype,
indices=indices,
)
logits = deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs)
@@ -342,11 +373,15 @@ def test_paged_mqa_logits():
sum_lens = context_lens.sum().item()
tflops_calc = 2 * sum_lens * next_n * num_heads * head_dim / 1e12
kv_bytes_per_token = head_dim / (2 if is_fp4 else 1) + 4
total_bytes = count_bytes(q, weights) + sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize)
# KV is read once per sequence; for varlen sum_lens overcounts (per-token), so use seq_sum_lens
kv_sum_lens = seq_sum_lens if is_varlen else sum_lens
total_bytes = count_bytes(q, weights) + kv_sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize)
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs), ('paged_mqa_logits', 'clean_logits'))
print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={batch_size:3}, NextN={next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: '
print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={raw_batch_size:3}, NextN={raw_next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: '
f'{tflops_calc / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, {total_bytes / t / 1e9:4.0f} GB/s', end='')
if is_varlen:
print(f' | Varlen, MaxTPB={max_tokens_per_batch}, NumTokens={batch_size}', end='')
print(f' | clean: {clean_t*1e6:3.0f} us' if clean_logits else '')
print()

View File

@@ -9,24 +9,28 @@ from typing import Tuple
import deep_gemm
from deep_gemm.utils import per_token_cast_to_fp4, per_token_cast_to_fp8
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
from deep_gemm.testing import bench, bench_kineto, calc_diff
from deep_gemm.testing import bench_kineto
# Load legacy implements from third-party
# noinspection PyBroadException
try:
import deep_ep
import importlib.util
from tilelang.profiler.bench import do_bench
spec = importlib.util.spec_from_file_location(
'tilelang_ops',
os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py'))
tilelang_ops = importlib.util.module_from_spec(spec)
sys.modules['tilelang_ops'] = tilelang_ops
spec.loader.exec_module(tilelang_ops)
is_legacy_loaded = True
except Exception as ex:
print(f'Failed to load legacy code: {ex}, skip baseline benchmarking')
is_legacy_loaded = False
def import_baseline():
# Load legacy implements from third-party
deep_ep, tilelang_ops, do_bench, is_legacy_loaded = None, None, None, False
# noinspection PyBroadException
try:
import deep_ep
import importlib.util
from tilelang.profiler.bench import do_bench
spec = importlib.util.spec_from_file_location(
'tilelang_ops',
os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py'))
tilelang_ops = importlib.util.module_from_spec(spec)
sys.modules['tilelang_ops'] = tilelang_ops
spec.loader.exec_module(tilelang_ops)
is_legacy_loaded = True
except Exception as ex:
dist_print(f'Failed to load legacy code: {ex}, skip baseline benchmarking', once_in_node=True)
dist_print(once_in_node=True)
return deep_ep, tilelang_ops, do_bench, is_legacy_loaded
# TODO: skip the test for SM90
@@ -51,29 +55,13 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_max_tokens_per_rank, num_topk,
hidden, intermediate_hidden
)
dist_print('Config:', once_in_node=True)
dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True)
dist_print(f' > Hidden: {hidden}', once_in_node=True)
dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True)
dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True)
dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True)
dist_print(once_in_node=True)
# Non-overlapped baseline: EP dispatch + GEMM + EP combine
alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
ep_buffer = deep_ep.ElasticBuffer(
group,
num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden,
num_topk=num_topk, use_fp8_dispatch=True,
explicitly_destroy=True,
allow_multiple_reduction=False,
gpu_timeout_secs=10, cpu_timeout_secs=30
) if is_legacy_loaded else None
# Create inputs
# noinspection PyGlobalUndefined
def create_inputs():
global x, topk_idx, topk_weights, l1_weights, l2_weights, transformed_l1_weights, transformed_l2_weights
global cumulative_local_expert_recv_stats_fused
global cumulative_local_expert_recv_stats_baseline
x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
l1_weights = torch.randn(
(num_experts_per_rank, intermediate_hidden * 2, hidden), dtype=torch.bfloat16, device='cuda')
@@ -81,6 +69,9 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
(num_experts_per_rank, hidden, intermediate_hidden), dtype=torch.bfloat16, device='cuda')
scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda')
topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)
cumulative_local_expert_recv_stats_fused = torch.randint(
0, 100, (num_experts_per_rank, ), dtype=torch.int, device='cuda')
cumulative_local_expert_recv_stats_baseline = cumulative_local_expert_recv_stats_fused.clone()
if args.masked_ratio > 0:
rand_mask = torch.rand_like(topk_idx, dtype=torch.float)
topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1)
@@ -109,12 +100,67 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
l2_weights = cast_grouped_weights_to_fp4(l2_weights)
transformed_l1_weights, transformed_l2_weights = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)
# Run fused mega MoE
# NOTES: copy x into buffer before each call because debug mode zeros the entire buffer
def run_fused():
buffer.x[:num_tokens].copy_(x[0])
buffer.x_sf[:num_tokens].copy_(x[1])
buffer.topk_idx[:num_tokens].copy_(topk_idx)
buffer.topk_weights[:num_tokens].copy_(topk_weights)
y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
# noinspection PyTypeChecker
deep_gemm.fp8_fp4_mega_moe(
y,
transformed_l1_weights, transformed_l2_weights,
buffer,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_fused,
activation_clamp=args.activation_clamp,
fast_math=bool(args.fast_math)
)
return y, cumulative_local_expert_recv_stats_fused
dist_print('Config:', once_in_node=True)
dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True)
dist_print(f' > Hidden: {hidden}', once_in_node=True)
dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True)
dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True)
dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True)
dist_print(once_in_node=True)
# Only do NCU profiling
if args.ncu_profile_only:
create_inputs()
dist_print(f'Run fused kernel:', once_in_node=True)
run_fused()
dist_print(f' > Done, exiting', once_in_node=True)
# Destroy and exit
dist.barrier()
buffer.destroy()
dist.destroy_process_group()
return
# Non-overlapped baseline: EP dispatch + GEMM + EP combine
deep_ep, tilelang_ops, tilelang_bench, is_legacy_loaded = import_baseline()
alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
ep_buffer = deep_ep.ElasticBuffer(
group,
num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden,
num_topk=num_topk, use_fp8_dispatch=True,
explicitly_destroy=True,
allow_multiple_reduction=False,
gpu_timeout_secs=10, cpu_timeout_secs=30
) if is_legacy_loaded else None
def run_baseline():
recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
x, topk_idx=topk_idx, topk_weights=topk_weights,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_baseline,
num_experts=num_experts, expert_alignment=alignment,
do_cpu_sync=False, do_handle_copy=False,
do_expand=True, use_tma_aligned_col_major_sf=True
do_expand=True, use_tma_aligned_col_major_sf=True,
)
n = recv_x[0].size(0)
l1_y = torch.empty((n, intermediate_hidden * 2), dtype=torch.bfloat16, device='cuda')
@@ -138,26 +184,7 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
l1_y, l2_weights, l2_y, handle.psum_num_recv_tokens_per_expert,
use_psum_layout=True, recipe=(1, 1, 32))
return ep_buffer.combine(l2_y, handle=handle)[0]
# Run fused mega MoE
# NOTES: copy x into buffer before each call because debug mode zeros the entire buffer
def run_fused():
buffer.x[:num_tokens].copy_(x[0])
buffer.x_sf[:num_tokens].copy_(x[1])
buffer.topk_idx[:num_tokens].copy_(topk_idx)
buffer.topk_weights[:num_tokens].copy_(topk_weights)
y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
# noinspection PyTypeChecker
deep_gemm.fp8_fp4_mega_moe(
y,
transformed_l1_weights, transformed_l2_weights,
buffer,
activation_clamp=args.activation_clamp,
fast_math=bool(args.fast_math)
)
return y
return ep_buffer.combine(l2_y, handle=handle)[0], cumulative_local_expert_recv_stats_baseline
# Check correctness (must be bitwise identical)
num_correctness_tests = 1 if args.num_correctness_tests is None else args.num_correctness_tests
@@ -166,34 +193,36 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
dist_print('Running correctness tests:', once_in_node=True)
for i in range(num_correctness_tests):
create_inputs()
assert torch.equal(run_fused(), run_baseline())
for fused_result, baseline_result in zip(run_fused(), run_baseline()):
assert torch.equal(fused_result, baseline_result)
if (i + 1) % 100 == 0 or i == num_correctness_tests - 1:
dist_print(f' > Correctness test #{i + 1}/{args.num_correctness_tests} passed', once_in_node=True)
dist_print(f' > Correctness test #{i + 1}/{num_correctness_tests} passed', once_in_node=True)
dist_print(once_in_node=True)
else:
create_inputs()
# Count local received tokens
gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
num_recv_tokens = (rank_idx * num_experts_per_rank <= gathered_topk_idx) & \
(gathered_topk_idx < (rank_idx + 1) * num_experts_per_rank)
num_recv_tokens = num_recv_tokens.sum().item()
gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | \
(gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1
num_recv_tokens = (gathered_topk_idx != -1).sum().item()
# Benchmark
t_fused = bench_kineto(
run_fused, 'mega_moe',
barrier=lambda: ep_buffer.barrier(use_comm_stream=False) if ep_buffer else dist.barrier(),
trace_path=None if not args.dump_profile_traces else f'{args.dump_profile_traces}/mega_moe_rank{rank_idx}.json')
t_baseline = do_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0
t_baseline = tilelang_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0
# TFLOPS: 3 matmuls (L1 left, L1 right, L2), each 2 * M * N * K
safe_div = lambda a, b: float('nan') if b == 0 else a / b
tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused)
# HBM bytes: weights (FP4 packed = 0.5 bytes) + activations (FP8 = 1 byte) + output (BF16 = 2 bytes)
num_touched_experts = torch.unique(gathered_topk_idx.flatten()).numel() - 1 # NOTES minus 1 to exclude "-1"
num_hbm_bytes = (
num_experts_per_rank * intermediate_hidden * 2 * hidden // 2 + # L1 weights (FP4)
num_experts_per_rank * hidden * intermediate_hidden // 2 + # L2 weights (FP4)
num_touched_experts * intermediate_hidden * 2 * hidden // 2 + # L1 weights (FP4)
num_touched_experts * hidden * intermediate_hidden // 2 + # L2 weights (FP4)
num_recv_tokens * hidden + # L1 acts read (FP8)
num_recv_tokens * intermediate_hidden + # L1 output write (FP8)
num_recv_tokens * intermediate_hidden + # L2 acts read (FP8)
@@ -230,7 +259,9 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Test PyTorch symmetric memory')
# Resource settings
parser.add_argument('--ncu-profile-only', action='store_true', help='Only run profiling without correctness test')
parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
# Model settings