diff --git a/README.md b/README.md index 07766c8..6ef705f 100644 --- a/README.md +++ b/README.md @@ -9,15 +9,15 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert ## News - 2026.04.16: Mega MoE, FP8xFP4 GEMM, FP4 Indexer, PDL, faster JIT compilation and more. - - Performance comparison will be posted later. - - Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details. + - Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details. + - For Mega MoE benchmarks, refer to [#316](https://github.com/deepseek-ai/DeepGEMM/pull/316). - 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer for DeepSeek v3.2. - - Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details. + - Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details. - 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module. - - NVRTC and post-compilation SASS optimization are all disabled. - - NVRTC will be supported later. - - As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported. - - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details. + - NVRTC and post-compilation SASS optimization are all disabled. + - NVRTC will be supported later. + - As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported. + - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details. - 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details. - 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases). - 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details. @@ -30,9 +30,9 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - Python 3.8 or higher - Compilers with C++20 support - CUDA Toolkit: - - CUDA 12.3 or higher for SM90 - - **We highly recommend 12.9 or higher for the best performance** - - CUDA 12.9 or higher for SM100 + - CUDA 12.3 or higher for SM90 + - **We highly recommend 12.9 or higher for the best performance** + - CUDA 12.9 or higher for SM100 - PyTorch 2.1 or higher - CUTLASS 4.0 or higher (could be cloned by Git submodule) - `{fmt}` library (could be cloned by Git submodule) @@ -159,30 +159,30 @@ The library provides some utility functions besides the above kernels: The library also provides some environment variables, which may be useful: - General - - `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default - - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default + - `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default + - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default - JIT cache - - `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default + - `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default - Compiler selection - - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, may have lower performance for some cases), `0` by default - - `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME` - - `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default + - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, may have lower performance for some cases), `0` by default + - `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME` + - `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default - Compiler output - - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default - - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default - - `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default - - `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default + - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default + - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default + - `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default + - `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default - Debug and profiling - - `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default - - `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default - - `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default - - `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default - - `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero symmetric buffer before each Mega MoE call for debugging, `0` by default - - `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default + - `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default + - `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default + - `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default + - `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default + - `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero symmetric buffer before each Mega MoE call for debugging, `0` by default + - `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default - Build options - - `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip CUDA extension build during installation, `0` by default - - `DG_FORCE_BUILD`: `0` or `1`, force local build instead of downloading pre-built wheels, `0` by default - - `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default + - `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip CUDA extension build during installation, `0` by default + - `DG_FORCE_BUILD`: `0` or `1`, force local build instead of downloading pre-built wheels, `0` by default + - `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp index 38ff225..505b0c0 100644 --- a/csrc/apis/attention.hpp +++ b/csrc/apis/attention.hpp @@ -190,12 +190,13 @@ static torch::Tensor fp8_fp4_mqa_logits(const std::tuple& indices) { // NOTES: Only 2D context lens is supported for now DG_HOST_ASSERT(context_lens.dim() == 2); const bool is_context_lens_2d = true; const int batch_size = context_lens.size(0); const int next_n = context_lens.size(1); + const bool is_varlen = indices.has_value(); DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); DG_HOST_ASSERT(context_lens.is_contiguous()); @@ -204,9 +205,16 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_ // Dispatch implementation const auto arch_major = device_runtime->get_arch_major(); - if (arch_major == 9 or arch_major == 10) { + if (is_varlen) { + const auto& indices_tensor = indices.value(); + DG_HOST_ASSERT(arch_major == 10 and next_n == 1 and (block_kv == 64 or block_kv == 32)); + DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size); + DG_HOST_ASSERT(indices_tensor.is_contiguous()); + DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt); + smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d, true, indices_tensor.data_ptr()); + } else if (arch_major == 9 or arch_major == 10) { DG_HOST_ASSERT(block_kv == 64 or (arch_major == 10 and block_kv == 32)); - smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d); + smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d, false, nullptr); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } @@ -222,7 +230,8 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple& indices) { const auto [q_fp, q_sf] = q; const bool is_fp4 = q_sf.has_value(); @@ -321,6 +330,17 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tupleget_arch_major(); + const auto indices_tensor = indices.value_or(torch::Tensor()); + if (is_varlen) { + DG_HOST_ASSERT(arch_major == 10 and next_n == 1); + DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size); + DG_HOST_ASSERT(indices_tensor.is_contiguous()); + DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt); + } + // Check schedule metadata auto [_schedule_meta_size, _meta_info_size] = get_shape<2>(schedule_meta); DG_HOST_ASSERT(_schedule_meta_size == num_sms + 1 and _meta_info_size == 2); @@ -344,15 +364,14 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tupleget_arch_major(); if (is_fp4 and arch_major == 10) { - sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, schedule_meta, + sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta, logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d, - aligned_max_context_len, block_table_stride, num_sms, split_kv); + is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv); } else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) { - smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, schedule_meta, + smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta, logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d, - aligned_max_context_len, block_table_stride, num_sms, split_kv); + is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } @@ -386,10 +405,11 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, const torch::Tensor& block_table, const torch::Tensor& schedule_meta, const int& max_context_len, - const bool& clean_logits) { + const bool& clean_logits, + const std::optional& indices) { return fp8_fp4_paged_mqa_logits(std::make_tuple(q, std::nullopt), fused_kv_cache, weights, context_lens, block_table, schedule_meta, - max_context_len, clean_logits, torch::kFloat); + max_context_len, clean_logits, torch::kFloat, indices); } #endif @@ -407,13 +427,15 @@ static void register_apis(pybind11::module_& m) { py::arg("max_seqlen_k") = 0, py::arg("logits_dtype") = torch::kFloat32); m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata, - py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms")); + py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"), + py::arg("indices") = std::nullopt); m.def("fp8_fp4_paged_mqa_logits", &fp8_fp4_paged_mqa_logits, py::arg("q"), py::arg("kv_cache"), py::arg("weights"), py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"), py::arg("max_context_len"), py::arg("clean_logits") = false, - py::arg("logits_dtype") = torch::kFloat32); + py::arg("logits_dtype") = torch::kFloat32, + py::arg("indices") = std::nullopt); // Legacy API m.def("fp8_mqa_logits", &fp8_mqa_logits, py::arg("q"), py::arg("kv"), py::arg("weights"), @@ -423,7 +445,8 @@ static void register_apis(pybind11::module_& m) { m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits, py::arg("q"), py::arg("kv_cache"), py::arg("weights"), py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"), - py::arg("max_context_len"), py::arg("clean_logits") = false); + py::arg("max_context_len"), py::arg("clean_logits") = false, + py::arg("indices") = std::nullopt); #endif } diff --git a/csrc/apis/mega.hpp b/csrc/apis/mega.hpp index 8129b79..efc3a78 100644 --- a/csrc/apis/mega.hpp +++ b/csrc/apis/mega.hpp @@ -11,6 +11,10 @@ namespace deep_gemm::mega { +static int get_token_alignment_for_mega_moe() { + return layout::kLCMCandidateBlockM; +} + static std::tuple(const torch::Tensor&)>> get_symm_buffer_size_for_mega_moe( const int& num_ranks, const int& num_experts, @@ -20,8 +24,7 @@ get_symm_buffer_size_for_mega_moe( DG_HOST_ASSERT(num_experts % num_ranks == 0); // Workspace bytes - const auto block_m = get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk); - const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk, block_m); + const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk); // Layouts const auto fp8_token_layout = layout::Data(hidden); @@ -49,14 +52,20 @@ get_symm_buffer_size_for_mega_moe( // Buffer configs const auto num_max_pool_tokens = static_cast(workspace.num_max_pool_tokens); - const auto num_padded_sf_pool_tokens = layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m); + int num_max_padded_sf_pool_tokens = 0; + for (int block_m: layout::kCandidateBlockM) { + num_max_padded_sf_pool_tokens = std::max( + num_max_padded_sf_pool_tokens, + layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m) + ); + } // L1 input buffer const auto l1_token_buffer = layout::Buffer( fp8_token_layout, 1, num_max_pool_tokens, input_topk_weights_buffer.get_end_ptr()); const auto l1_sf_buffer = layout::Buffer( - fp8_sf_layout, 1, num_padded_sf_pool_tokens, + fp8_sf_layout, 1, num_max_padded_sf_pool_tokens, l1_token_buffer.get_end_ptr()); const auto l1_topk_weights_buffer = layout::Buffer( l1_topk_weights_layout, 1, num_max_pool_tokens, @@ -67,7 +76,7 @@ get_symm_buffer_size_for_mega_moe( fp8_intermediate_token_layout, 1, num_max_pool_tokens, l1_topk_weights_buffer.get_end_ptr()); const auto l2_sf_buffer = layout::Buffer( - fp8_intermediate_sf_layout, 1, num_padded_sf_pool_tokens, + fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens, l2_token_buffer.get_end_ptr()); // Combine input buffer: BF16 tokens for cross-rank combine @@ -77,7 +86,7 @@ get_symm_buffer_size_for_mega_moe( // Check SF buffer requirements DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); - DG_HOST_ASSERT(num_padded_sf_pool_tokens % 4 == 0); + DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0); // Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer // NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major @@ -104,8 +113,8 @@ get_symm_buffer_size_for_mega_moe( torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); auto l1_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), - {num_padded_sf_pool_tokens, hidden / 128}, - {1, num_padded_sf_pool_tokens}, + {num_max_padded_sf_pool_tokens, hidden / 128}, + {1, num_max_padded_sf_pool_tokens}, torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); auto l2_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), @@ -113,8 +122,8 @@ get_symm_buffer_size_for_mega_moe( torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); auto l2_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), - {num_padded_sf_pool_tokens, intermediate_hidden / 128}, - {1, num_padded_sf_pool_tokens}, + {num_max_padded_sf_pool_tokens, intermediate_hidden / 128}, + {1, num_max_padded_sf_pool_tokens}, torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf); }; @@ -123,8 +132,9 @@ get_symm_buffer_size_for_mega_moe( static void fp8_fp4_mega_moe( const torch::Tensor& y, - const std::tuple& l1_weights_, - const std::tuple& l2_weights_, + const std::tuple& l1_weights_tuple, + const std::tuple& l2_weights_tuple, + const std::optional& cumulative_local_expert_recv_stats, const torch::Tensor& sym_buffer, const std::vector& sym_buffer_ptrs, const int& rank_idx, const int& num_max_tokens_per_rank, @@ -132,9 +142,10 @@ static void fp8_fp4_mega_moe( const std::tuple& recipe, const std::string& activation, const std::optional& activation_clamp_opt, - const bool& fast_math) { - const auto [l1_weights, l1_weights_sf] = l1_weights_; - const auto [l2_weights, l2_weights_sf] = l2_weights_; + const bool& fast_math +) { + const auto [l1_weights, l1_weights_sf] = l1_weights_tuple; + const auto [l2_weights, l2_weights_sf] = l2_weights_tuple; // Config checks const auto num_tokens = static_cast(y.size(0)); @@ -161,13 +172,20 @@ static void fp8_fp4_mega_moe( DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden); DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous()); - // Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment + // Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment constexpr int kGranMN = 1, kGranK = 32; check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK, num_experts_per_rank, true, false, torch::kInt); check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK, num_experts_per_rank, true, false, torch::kInt); + // Check stats counter + if (cumulative_local_expert_recv_stats.has_value()) { + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous()); + } + // Check buffer bytes const auto num_ranks = static_cast(sym_buffer_ptrs.size()); const auto num_experts_ = num_experts_per_rank * num_ranks; @@ -175,7 +193,7 @@ static void fp8_fp4_mega_moe( num_ranks, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden, - true, "swiglu"); + true, activation); DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast(num_required_bytes)); DG_HOST_ASSERT(num_experts == num_experts_); @@ -189,6 +207,7 @@ static void fp8_fp4_mega_moe( l2_acts, l2_acts_sf, l1_weights, l2_weights, l1_weights_sf, l2_weights_sf, + cumulative_local_expert_recv_stats, sym_buffer_ptrs, rank_idx, num_max_tokens_per_rank, num_experts_per_rank, @@ -207,7 +226,7 @@ static void fp8_fp4_mega_moe( static void register_apis(pybind11::module_& m) { #if DG_TENSORMAP_COMPATIBLE - m.def("get_block_m_for_mega_moe", &get_block_m_for_mega_moe); + m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe); m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe); m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe); #endif diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index 2caa6f1..b1ba6bd 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -55,38 +55,68 @@ struct MegaMoEConfig { } }; -static int get_block_m_for_mega_moe(const int& num_ranks, const int& num_experts, - const int& num_max_tokens_per_rank, const int& num_topk) { - // TODO: compute based on configs - return 192; +static std::tuple get_block_config_for_mega_moe( + const int& num_ranks, const int& num_experts, + const int& num_max_tokens_per_rank, const int& num_topk, + const int& num_tokens) { + const auto& [cluster_size, block_m, store_block_m, num_epilogue_warpgroups] = [&]() -> std::tuple { + float num_expected_tokens_per_expert = static_cast(num_tokens) * num_ranks * num_topk / num_experts; + if (num_expected_tokens_per_expert <= 8.5) { + // Really small token-per-expert (e.g. RL long-tail rollout), use the smallest block_m + return {2, 16, 8, 2}; + } else if (num_expected_tokens_per_expert <= 16.5) { + // Small batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 128 + return {2, 32, 16, 2}; + } else if (num_expected_tokens_per_expert <= 32.5) { + // Medium batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 256 + return {2, 64, 32, 1}; + } else if (num_expected_tokens_per_expert <= 64.5) { + // Large batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 512 + return {2, 96, 16, 2}; + } else if (num_expected_tokens_per_expert <= 96.5) { + // Medium batch size, Medium EP, decoding, e.g. 6/384 experts, EP16, bsz 256, or EP32, bsz128 + return {2, 128, 32, 2}; + } else { + // Prefill, or large EP decoding + return {2, 192, 32, 2}; + } + }(); + + // Check whether our `block_m` lies in `kCandidateBlockM` + DG_HOST_ASSERT(std::any_of( + layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs, + [=](const auto& candidate) { return candidate == block_m; }) + ); + + // Return configs + return {cluster_size, block_m, store_block_m, num_epilogue_warpgroups * 128}; } static int get_num_experts_per_wave_for_mega_moe( const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) { + + float expected_tokens_per_expert = static_cast(num_tokens) * num_topk / num_experts_per_rank; + if (expected_tokens_per_expert < 1) { + // Most experts don't have tokens, calculate all experts at once + return num_experts_per_rank; + } + // Reduce per-expert block count by this factor since uneven routing leaves some experts with fewer tokens constexpr int kImbalanceFactor = 2; - // TODO: support num_experts_per_rank > 32 - // Find the largest divisor of num_experts_per_rank that fits in 32 as the upper bound - int max_num_experts_per_wave = std::min(32, num_experts_per_rank); - while (max_num_experts_per_wave > 1 and num_experts_per_rank % max_num_experts_per_wave != 0) - -- max_num_experts_per_wave; - // Count L1 blocks per expert assuming tokens are evenly spread across experts - const int expected_tokens_per_expert = - num_tokens * num_topk / num_experts_per_rank + 1; - const int num_m_blocks = ceil_div(expected_tokens_per_expert, block_m); - const int num_n_blocks = intermediate_hidden / block_n; + const int num_m_blocks = ceil_div(static_cast(std::ceil(expected_tokens_per_expert)), block_m); + const int num_n_blocks = (2 * intermediate_hidden) / block_n; const int num_l1_blocks_per_expert = num_m_blocks * num_n_blocks; // Pick the smallest value whose total blocks (after imbalance reduction) can keep all SMs busy int num_experts_per_wave = num_l1_blocks_per_expert > 0 ? ceil_div(kImbalanceFactor * num_sms, num_l1_blocks_per_expert) : 1; - num_experts_per_wave = std::min(num_experts_per_wave, max_num_experts_per_wave); + num_experts_per_wave = std::min(num_experts_per_wave, num_experts_per_rank); // Round up to the nearest divisor of num_experts_per_rank so every wave processes the same count - while (num_experts_per_wave < max_num_experts_per_wave and num_experts_per_rank % num_experts_per_wave != 0) + while (num_experts_per_wave < num_experts_per_rank and num_experts_per_rank % num_experts_per_wave != 0) ++ num_experts_per_wave; return num_experts_per_wave; @@ -148,18 +178,18 @@ static std::pair get_pipeline_config_for_mega_moe( static MegaMoEConfig get_mega_moe_config( const int& num_ranks, const int& num_experts, const int& num_experts_per_rank, const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, - const int& hidden, const int& intermediate_hidden) { - // Block tiling - const int block_m = get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk); + const int& hidden, const int& intermediate_hidden, + const int& num_padded_sf_pool_tokens) { + // Block config + const auto [cluster_size, block_m, store_block_m, num_epilogue_threads] = + get_block_config_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens); const int block_n = 128; const int block_k = 128; const int load_block_m = block_m / 2; const int load_block_n = block_n; - const int store_block_m = 32; const auto [sf_block_m, sf_block_n] = SM100ArchSpec::get_sf_uttcp_aligned_block_sizes(block_m, block_n, MmaKind::MXFP8FP4); const int num_max_pool_tokens = layout::get_num_max_pool_tokens( - num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank, block_m); - const int num_padded_sf_pool_tokens = layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m); + num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank); // NOTES: FP8 activations and FP4 weights (unpacked to 8-bit in smem) both use 128B swizzle const int swizzle_acts_mode = 128; const int swizzle_weights_mode = 128; @@ -173,7 +203,6 @@ static MegaMoEConfig get_mega_moe_config( // Thread layout const int num_dispatch_threads = 128; const int num_non_epilogue_threads = 128; - const int num_epilogue_threads = 256; // Pipeline const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe( diff --git a/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp b/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp index a9afda2..4d91256 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp @@ -29,6 +29,7 @@ public: // Runtime arguments void* y; + int* cumulative_local_expert_recv_stats; int num_tokens; layout::SymBuffer<> sym_buffer_ptrs; @@ -91,6 +92,7 @@ static void __instantiate_kernel() {{ // TODO: optimize `args` copy DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.y, + args.cumulative_local_expert_recv_stats, args.num_tokens, args.sym_buffer_ptrs, args.tensor_map_l1_acts, @@ -112,6 +114,7 @@ static void sm100_fp8_fp4_mega_moe( const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf, const torch::Tensor& l1_weights, const torch::Tensor& l2_weights, const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf, + const std::optional cumulative_local_expert_recv_stats, const std::vector& sym_buffer_ptrs, const int& rank_idx, const int& num_max_tokens_per_rank, const int& num_experts_per_rank, @@ -122,11 +125,12 @@ static void sm100_fp8_fp4_mega_moe( ) { const auto num_ranks = static_cast(sym_buffer_ptrs.size()); const auto num_experts = num_experts_per_rank * num_ranks; + const auto num_padded_sf_pool_tokens = static_cast(l1_acts_sf.size(0)); // Heuristics const auto config = get_mega_moe_config( num_ranks, num_experts, num_experts_per_rank, - num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden); + num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden, num_padded_sf_pool_tokens); // Make tensormap constexpr int kGranK = 32; @@ -175,6 +179,11 @@ static void sm100_fp8_fp4_mega_moe( config.block_n, kGranK, num_experts_per_rank, 0); + // Stats can be optional + int* cumulative_local_expert_recv_stats_ptr = nullptr; + if (cumulative_local_expert_recv_stats.has_value()) + cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr(); + // Launch const auto num_sms = device_runtime->get_num_sms(); const SM100FP8FP4MegaMoERuntime::Args args = { @@ -186,6 +195,7 @@ static void sm100_fp8_fp4_mega_moe( .fast_math = fast_math, .config = config, .y = y.data_ptr(), + .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, .num_tokens = num_tokens, .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx), .tensor_map_l1_acts = tensor_map_l1_acts, diff --git a/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp index 89d1fd6..2a3288e 100644 --- a/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp @@ -14,11 +14,13 @@ public: int aligned_batch_size; int split_kv; int num_sms; - + bool is_varlen; + int batch_size; int next_n; bool is_context_lens_2d; int* context_lens; + int* indices; int* schedule_metadata; LaunchArgs launch_args; @@ -32,10 +34,10 @@ using namespace deep_gemm; static void __instantiate_kernel() {{ auto ptr = reinterpret_cast(&sched::smxx_paged_mqa_logits_metadata< - {}, {}, {} + {}, {}, {}, {} >); }}; -)", args.aligned_batch_size, args.split_kv, args.num_sms); +)", args.aligned_batch_size, args.split_kv, args.num_sms, args.is_varlen ? "true" : "false"); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -44,6 +46,7 @@ static void __instantiate_kernel() {{ args.next_n, args.is_context_lens_2d, args.context_lens, + args.indices, args.schedule_metadata )); } @@ -53,14 +56,15 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens, const torch::Tensor& schedule_metadata, const int& batch_size, const int& next_n, const int& block_kv, const int& num_sms, - const bool& is_context_lens_2d) { + const bool& is_context_lens_2d, + const bool& is_varlen, const int* indices_ptr) { constexpr int split_kv = 256; constexpr int num_threads = 32; const int aligned_batch_size = align(batch_size, 32); DG_HOST_ASSERT(split_kv % block_kv == 0); - // Calculate shared memory size - const int smem_size = aligned_batch_size * static_cast(sizeof(int)); + // Shared memory: prefix_sum[kAlignedBatchSize] + varlen_atom_token_start/context_len[kAlignedBatchSize] + varlen_num_atoms + const int smem_size = (3 * aligned_batch_size + 1) * static_cast(sizeof(int)); DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); @@ -69,10 +73,12 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens, .aligned_batch_size = aligned_batch_size, .split_kv = split_kv, .num_sms = num_sms, + .is_varlen = is_varlen, .batch_size = batch_size, .next_n = next_n, .is_context_lens_2d = is_context_lens_2d, .context_lens = context_lens.data_ptr(), + .indices = const_cast(indices_ptr), .schedule_metadata = schedule_metadata.data_ptr(), .launch_args = LaunchArgs(1, num_threads, smem_size) }; @@ -90,6 +96,7 @@ public: int head_dim; int block_kv; bool is_context_lens_2d; + bool is_varlen; int block_table_stride; int logits_stride; @@ -100,6 +107,7 @@ public: int* context_lens; void* logits; int* block_table; + int* indices; int* schedule_meta; CUtensorMap tensor_map_q; @@ -129,7 +137,7 @@ static void __instantiate_kernel() {{ auto ptr = reinterpret_cast(&sm{}_fp8_paged_mqa_logits< {}, {}, {}, {}, - {}, + {}, {}, {}, {}, {}, {}, {}, @@ -139,7 +147,7 @@ static void __instantiate_kernel() {{ )", arch, arch, args.next_n, args.num_heads, args.head_dim, args.block_kv, - args.is_context_lens_2d, + args.is_context_lens_2d, args.is_varlen ? "true" : "false", args.num_q_stages, args.num_kv_stages, args.split_kv, args.num_specialized_threads, args.num_math_threads, @@ -151,7 +159,7 @@ static void __instantiate_kernel() {{ args.batch_size, args.logits_stride, args.block_table_stride, args.context_lens, args.logits, - args.block_table, args.schedule_meta, + args.block_table, args.indices, args.schedule_meta, args.tensor_map_q, args.tensor_map_kv, args.tensor_map_kv_scales, args.tensor_map_weights )); @@ -165,12 +173,14 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, const torch::Tensor& context_lens, const torch::Tensor& logits, const torch::Tensor& block_table, + const torch::Tensor& indices, const torch::Tensor& schedule_meta, const at::ScalarType& logits_dtype, const int& batch_size, const int& next_n, const int& num_heads, const int& head_dim, const int& num_kv_blocks, const int& block_kv, const bool& is_context_lens_2d, + const bool& is_varlen, const int& logits_stride, const int& block_table_stride, const int& num_sms, @@ -183,7 +193,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0); // Construct TMAs - const int next_n_atom = (next_n % 2 == 0) ? 2 : 1; + const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1; const auto tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads, head_dim, next_n_atom * num_heads, static_cast(q.stride(2)), @@ -245,6 +255,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, .head_dim = head_dim, .block_kv = block_kv, .is_context_lens_2d = is_context_lens_2d, + .is_varlen = is_varlen, .block_table_stride = block_table_stride, .logits_stride = logits_stride, .num_q_stages = num_q_stages, @@ -253,6 +264,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, .context_lens = context_lens.data_ptr(), .logits = logits.data_ptr(), .block_table = block_table.data_ptr(), + .indices = is_varlen ? indices.data_ptr() : nullptr, .schedule_meta = schedule_meta.data_ptr(), .tensor_map_q = tensor_map_q, .tensor_map_kv = tensor_map_kv, @@ -279,6 +291,7 @@ public: int head_dim; int block_kv; bool is_context_lens_2d; + bool is_varlen; int block_table_stride; int logits_stride; @@ -289,6 +302,7 @@ public: int* context_lens; void* logits; int* block_table; + int* indices; int* schedule_meta; CUtensorMap tensor_map_q; @@ -314,7 +328,7 @@ static void __instantiate_kernel() {{ auto ptr = reinterpret_cast(&sm100_fp4_paged_mqa_logits< {}, {}, {}, {}, - {}, + {}, {}, {}, {}, {}, {}, {}, @@ -323,7 +337,7 @@ static void __instantiate_kernel() {{ }}; )", args.next_n, args.num_heads, args.head_dim, args.block_kv, - args.is_context_lens_2d, + args.is_context_lens_2d, args.is_varlen ? "true" : "false", args.num_q_stages, args.num_kv_stages, args.split_kv, args.num_specialized_threads, args.num_math_threads, @@ -335,7 +349,7 @@ static void __instantiate_kernel() {{ args.batch_size, args.logits_stride, args.block_table_stride, args.context_lens, args.logits, - args.block_table, args.schedule_meta, + args.block_table, args.indices, args.schedule_meta, args.tensor_map_q, args.tensor_map_sf_q, args.tensor_map_kv, args.tensor_map_sf_kv, args.tensor_map_weights @@ -351,12 +365,14 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q, const torch::Tensor& context_lens, const torch::Tensor& logits, const torch::Tensor& block_table, + const torch::Tensor& indices, const torch::Tensor& schedule_meta, const at::ScalarType& logits_dtype, const int& batch_size, const int& next_n, const int& num_heads, const int& head_dim, const int& num_kv_blocks, const int& block_kv, const bool& is_context_lens_2d, + const bool& is_varlen, const int& logits_stride, const int& block_table_stride, const int& num_sms, @@ -366,8 +382,8 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q, DG_HOST_ASSERT(split_kv == 256 and logits_stride % split_kv == 0); // TODO: tuning num_stages - const int num_q_stages = 3, num_kv_stages = 6, num_tmem_stages = 3; - const int next_n_atom = (next_n % 2 == 0) ? 2 : 1; + const int num_q_stages = 3, num_kv_stages = 10, num_tmem_stages = 3; + const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1; // `head_dim` must be 128 for 64B swizzling DG_HOST_ASSERT(head_dim == 128); @@ -416,6 +432,7 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q, .head_dim = head_dim, .block_kv = block_kv, .is_context_lens_2d = is_context_lens_2d, + .is_varlen = is_varlen, .block_table_stride = block_table_stride, .logits_stride = logits_stride, .num_q_stages = num_q_stages, @@ -424,6 +441,7 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q, .context_lens = context_lens.data_ptr(), .logits = logits.data_ptr(), .block_table = block_table.data_ptr(), + .indices = is_varlen ? indices.data_ptr() : nullptr, .schedule_meta = schedule_meta.data_ptr(), .tensor_map_q = tensor_map_q, .tensor_map_sf_q = tensor_map_sf_q, diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index a36122a..a9542e2 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -123,4 +123,4 @@ _C.init( _find_cuda_home() # CUDA home ) -__version__ = '2.4.2' +__version__ = '2.5.0' diff --git a/deep_gemm/include/deep_gemm/comm/barrier.cuh b/deep_gemm/include/deep_gemm/comm/barrier.cuh index 8cc9263..eb9858d 100644 --- a/deep_gemm/include/deep_gemm/comm/barrier.cuh +++ b/deep_gemm/include/deep_gemm/comm/barrier.cuh @@ -1,11 +1,20 @@ #pragma once +#include + #include #include #include namespace deep_gemm::comm { +CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() { + // Perform cluster_sync with `barrier.cluster.arrive.relaxed` + // This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); +} + template CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace, const uint32_t& sm_idx, const uint32_t& thread_idx, diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh index 3d92935..b8a99fd 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh @@ -401,7 +401,8 @@ void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, for (uint32_t i = 0; i < BLOCK_Q; ++ i) { // Load accumulator from TMEM uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; - tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); // Release TMEM empty if (i == BLOCK_Q - 1) { diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh index 1bf0025..d9add53 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh @@ -20,7 +20,7 @@ namespace deep_gemm { template = 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); // UMMA configs static constexpr uint32_t kNumTmemStages = 3; @@ -157,7 +158,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, // Scheduler constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; - using Scheduler = sched::PagedMQALogitsScheduler; + using Scheduler = sched::PagedMQALogitsScheduler; DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); // Make Q, KV and TMEM pipeline @@ -182,7 +183,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, cutlass::arch::warpgroup_reg_dealloc(); if (cute::elect_one_sync()) { - auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); // Persistently schedule over blocks // Initialize outside valid range to indicate no previous task @@ -196,11 +197,12 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); // Issue TMA Q + const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx); cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]), static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), - smem_q[q_stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads); - tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_atom_idx * kNextNAtom); - tma::copy(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_atom_idx * kNextNAtom); + smem_q[q_stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx); + tma::copy(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx); full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); } last_q_atom_idx = q_atom_idx; @@ -210,7 +212,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, } else if (warp_idx == kSpecWarpStart + 1) { // TMA warp for loading KV cache cutlass::arch::warpgroup_reg_dealloc(); - auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); // Persistently schedule over blocks uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage; @@ -225,10 +227,11 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, // Coalesced load of block table if (kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; - const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast(block_table_stride); + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); kv_block_idx_storage = (kv_idx + lane_idx < num_kv) ? block_table[block_table_offset + kv_idx + lane_idx] : 0; } + __syncwarp(); // Broadcast KV block indices int kv_block_idx[kNumBlocksPerSplit]; @@ -240,7 +243,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, // Wait KV consumer release CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); - + // Issue TMA KV if (cute::elect_one_sync()) { empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); @@ -260,7 +263,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, } else if (warp_idx == kSpecWarpStart + 2) { // UMMA warp cutlass::arch::warpgroup_reg_dealloc(); - auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // UTCCP transposer @@ -371,7 +374,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, } else if (warp_idx < kSpecWarpStart) { // Math warpgroups for reduce cutlass::arch::warpgroup_reg_alloc(); - auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); const auto math_warpgroup_idx = warpgroup_idx; const auto math_thread_idx = warp_idx * 32 + lane_idx; @@ -400,6 +403,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, // Persistently schedule over blocks uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; uint32_t q_atom_idx, kv_idx, _; + bool is_paired_atom = false; while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { if (q_atom_idx != last_q_atom_idx) { CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); @@ -423,11 +427,16 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, weights[i][j + 3] = raw.w; } } + + // Check if this atom pairs two tokens from the same sequence + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2); + } } last_q_atom_idx = q_atom_idx; // Calculate KV offset in advance - auto kv_offset = q_atom_idx * kNextNAtom * static_cast(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx; + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx; // Advance pipeline by `kNumMathWarpGroups` steps // Wait UMMA arrival @@ -436,53 +445,58 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, ptx::tcgen05_after_thread_sync(); // Reduce over the head dim and store - #pragma unroll - for (uint32_t i = 0; i < kNextNAtom; ++ i) { - // Load accumulator from TMEM - uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; - tmem_load(cute::Int{}, tmem_addr, accum); + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + + // Only loop over valid iterations + #pragma unroll + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride)] = result; + __syncwarp(); + } // Release TMEM empty - if (i == kNextNAtom - 1) { - ptx::tcgen05_before_thread_sync(); - empty_tmem_barriers[tmem_stage_idx]->arrive(); - } + ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + }; - // Accumulate weighted ReLU in parallel - auto sum_0 = make_float2(0, 0); - auto sum_1 = make_float2(0, 0); - - const auto transform = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(weights[i][j], weights[i][j + 1]); - return __ffma2_rn(a, b, sum); - }; - - #pragma unroll - for (uint32_t j = 0; j < kNumHeads; j += 4) { - sum_0 = transform(j, sum_0); - sum_1 = transform(j + 2, sum_1); - } - - auto sum = __fadd2_rn(sum_0, sum_1); - auto result = static_cast(sum.x + sum.y); - - // Store into the global memory - const auto dst_offset = kv_offset + i * static_cast(logits_stride); - if constexpr(sizeof(logits_dtype_t) == 2) { - // Pack two adjacent bf16 lanes into uint32 for wider store - uint16_t my_bits = *reinterpret_cast(&result); - uint16_t neighbor_bits = __shfl_down_sync(0xffffffff, my_bits, 1); - uint32_t packed; - asm volatile("mov.b32 %0, {%1, %2};" : "=r"(packed) : "h"(my_bits), "h"(neighbor_bits)); - if (lane_idx % 2 == 0) - *reinterpret_cast(logits + dst_offset) = packed; - } else { - logits[dst_offset] = result; - } - // this sync warp prevent the next load tmem from reordering - // nvcc may reorder it to overlap with the current tmem load, lead to large register usage - __syncwarp(); + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); } } diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh index 58fb343..b2adc6c 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh @@ -48,6 +48,7 @@ template < > CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void sm100_fp8_fp4_mega_moe_impl(void* y, + int* cumulative_local_expert_recv_stats, const uint32_t num_tokens, const __grid_constant__ layout::SymBuffer sym_buffer, const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, @@ -91,7 +92,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y, // Workspaces const auto workspace = layout::Workspace( - sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk, BLOCK_M); + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); // Token and buffer layouts constexpr auto fp8_token_layout = layout::Data(kHidden); @@ -170,7 +171,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y, constexpr uint32_t UMMA_K = 32; constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; - DG_STATIC_ASSERT(BLOCK_M % 32 == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M"); DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N"); DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); @@ -269,7 +270,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y, auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2); // A cluster sync is essential for 2CTA tensor memory allocation - cute::cluster_sync(); + comm::cluster_sync_with_relaxed_arrive(); // Initialization if (warp_idx == 0) { @@ -307,7 +308,9 @@ sm100_fp8_fp4_mega_moe_impl(void* y, // Allocate tensor memory Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); } - cute::cluster_sync(); + // NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`, + // and `barrier.cluster.wait.aligned` is by default `.acquire` + comm::cluster_sync_with_relaxed_arrive(); // Task scheduler auto scheduler = sched::MegaMoEScheduler< @@ -599,7 +602,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y, __syncwarp(); } - // Clean workspace for the next usage + // Clean workspace for the next usage, and also do cumulative stats // NOTES: it is overlapped with combine reduction epilogue ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); @@ -623,19 +626,27 @@ sm100_fp8_fp4_mega_moe_impl(void* y, // Wait read count ready ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); - // Clean expert token count - if (thread_idx == 0) + // Clean expert token count, and add cumulative results + DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); + if (warp_idx == 0) { *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } // Clean per-rank token count for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); // Clean L1 and L2 arrival stuffs for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; } + __syncwarp(); } } @@ -672,23 +683,22 @@ sm100_fp8_fp4_mega_moe_impl(void* y, const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); const auto expected = scheduler.template get_valid_m(); while (ptx::ld_acq(ptr) != expected); + } else { + // The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival + // NOTES: Originally we wait blocks on-demand to overlap L1 calculation + // with L2, but this optimization is negative when `num_experts_per_wave` + // guarantees L1's completion when L2 starts. So we remove it. + // In the future, if `num_experts_per_wave` is not large enough + // due to small `num_experts_per_rank`, we may need to add it back or add a switch + DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes"); + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + // NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts + // to avoid undefined behavior when `num_k_blocks == 32` + const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1; + while (ptx::ld_acq_gpu(ptr) != expected); } - uint64_t cached_l2_arrival_mask = 0; for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { - // Wait current K block arrival - if (block_phase == sched::BlockPhase::Linear2) { - // The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2 L1 blocks' arrival - DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes"); - const uint64_t needed = 3ull << (k_block_idx * 2); - if ((cached_l2_arrival_mask & needed) != needed) { - const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); - do { - cached_l2_arrival_mask = ptx::ld_acq_gpu(ptr); - } while ((cached_l2_arrival_mask & needed) != needed); - } - } - // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); @@ -953,8 +963,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y, // Load weights from global into register cache per 32 tokens DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size"); - DG_STATIC_ASSERT(WG_BLOCK_M % 32 == 0, "Invalid block size"); - if ((j * ATOM_M) % 32 == 0) { + if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) { stored_cached_weight = *l1_topk_weights_buffer .get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx) .get_base_ptr(); @@ -1060,19 +1069,26 @@ sm100_fp8_fp4_mega_moe_impl(void* y, // Only one warp per pair writes (both hold the same SF after cross-warp reduce) // Each lane < 4 holds SF for 2 rows (sf.x and sf.y) if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) { - // TODO: I believe the expression can be optimized - const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M - + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2; const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2; const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t); const auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); // NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4 + // NOTES: originally there was: + // - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2 + // - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)` + // We find out that + // 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside + // 2. `lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside + // This reduce the number of computation instructions. + const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + __builtin_assume(token_base_idx < BLOCK_M); const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M - + transform_sf_token_idx(token_idx_in_expert); - sf_base_ptr[k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx] = + + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; + const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; + sf_base_ptr[sf_addr] = (*reinterpret_cast(&sf.x) >> 23); - sf_base_ptr[k_uint_idx * mn_stride + (sf_pool_token_idx + 4) * static_cast(sizeof(uint32_t)) + byte_idx] = + sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = (*reinterpret_cast(&sf.y) >> 23); } __syncwarp(); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh index dd38d93..9a5bddb 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -20,7 +20,7 @@ namespace deep_gemm { template = 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); // Shared memory configs static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; @@ -136,7 +137,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Scheduler constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; - using Scheduler = sched::PagedMQALogitsScheduler; + using Scheduler = sched::PagedMQALogitsScheduler; DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); // Q and KV pipeline @@ -157,13 +158,14 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, if (warp_idx == kSpecWarpStart) { // TMA warp for loading data cutlass::arch::warpgroup_reg_dealloc(); - auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); uint32_t q_iter_idx = 0, kv_iter_idx = 0; - const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_atom_idx) { + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) { if (cute::elect_one_sync()) { - tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads); - tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_atom_idx * kNextNAtom); + const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; @@ -182,7 +184,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, while (fetched_next_task) { // Prefetch next Q when (q, atom) changes - bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + 1); + const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size); + bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance); if (q_atom_idx != next_q_atom_idx) kv_block_idx_ptr = 32; @@ -195,17 +198,18 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // TODO(xuzhean): consider -1 if (kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; - const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast(block_table_stride); + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); kv_block_idx_storage = (kv_idx + lane_idx < num_kv) ? block_table[block_table_offset + kv_idx + lane_idx] : 0; } + __syncwarp(); DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); // Wait Q consumer release and issue TMA Q if (prefetch_q) { CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); - issue_tma_q(q_stage_idx, q_atom_idx + 1); + issue_tma_q(q_stage_idx, q_atom_idx + next_advance); } uint32_t kv_block_idx[kNumBlocksPerSplit]; @@ -236,7 +240,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, } } else if (warp_idx == kSpecWarpStart + 1) { cutlass::arch::warpgroup_reg_dealloc(); - auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); uint32_t q_iter_idx = 0, kv_iter_idx = 0; // Require full allocation @@ -292,7 +296,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, } else if (warp_idx < kSpecWarpStart) { // Math warpgroups for reduce cutlass::arch::warpgroup_reg_alloc(); - auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); uint32_t q_iter_idx = 0, kv_iter_idx = 0; // Offsets @@ -321,6 +325,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; uint32_t umma_phase = 0; + bool is_paired_atom = false; while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { // Q or atom changes @@ -340,6 +345,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, for (uint32_t j = 0; j < kNumHeads; ++ j) weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); } + + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2); + } } // Get current task indices @@ -347,7 +356,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_atom_idx * kNextNAtom * static_cast(logits_stride) + kv_idx * BLOCK_KV; + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV; // Wait TMA KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); @@ -367,40 +376,56 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Reduce over the head dim and store DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); - #pragma unroll - for (uint32_t i = 0; i < kNextNAtom; ++ i) { - // Load accumulator from TMEM + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; float accum[kNumHeads]; - tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); - - // Release TMEM empty - if (i == kNextNAtom - 1) { - ptx::tcgen05_before_thread_sync(); - empty_umma_barriers[math_warpgroup_idx]->arrive(); - } - - // Accumulate weighted ReLU in parallel - auto sum_0 = make_float2(0, 0); - auto sum_1 = make_float2(0, 0); - - const auto transform = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(weights[i][j], weights[i][j + 1]); - return __ffma2_rn(a, b, sum); - }; #pragma unroll - for (uint32_t j = 0; j < kNumHeads; j += 4) { - sum_0 = transform(j, sum_0); - sum_1 = transform(j + 2, sum_1); + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(scale_kv * (sum.x + sum.y)); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride) + math_thread_idx] = result; + __syncwarp(); } - auto sum = __fadd2_rn(sum_0, sum_1); - auto result = static_cast(scale_kv * (sum.x + sum.y)); + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[math_warpgroup_idx]->arrive(); + }; - // Store into the global memory - logits[kv_offset + i * static_cast(logits_stride) + math_thread_idx] = result; - __syncwarp(); + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); } } diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh index 0610a2e..cc2592b 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -21,7 +21,7 @@ namespace deep_gemm { template ::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; @@ -132,8 +135,8 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, cudaGridDependencySynchronize(); // Scheduler - auto scheduler = sched::PagedMQALogitsScheduler( - blockIdx.x, context_lens, schedule_meta); + auto scheduler = sched::PagedMQALogitsScheduler( + blockIdx.x, batch_size, context_lens, schedule_meta, indices); DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); // Q and KV pipeline diff --git a/deep_gemm/include/deep_gemm/layout/mega_moe.cuh b/deep_gemm/include/deep_gemm/layout/mega_moe.cuh index fe3c416..13520c6 100644 --- a/deep_gemm/include/deep_gemm/layout/mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/layout/mega_moe.cuh @@ -1,19 +1,27 @@ #pragma once +#include + #include #include namespace deep_gemm::layout { -// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding +static constexpr int kNumCandidateBlockMs = 7; +static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192}; +static constexpr int kMaxCandidateBlockM = 192; +static constexpr int kMinCandidateBlockM = 8; +static constexpr int kLCMCandidateBlockM = 384; + +// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M template CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk, - T num_experts_per_rank, T block_m) { + T num_experts_per_rank) { const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank; const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank); return math::constexpr_align( - num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (block_m - 1), - block_m); + num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast(kMaxCandidateBlockM) - 1), + static_cast(kLCMCandidateBlockM)); } // SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M @@ -48,17 +56,14 @@ struct Workspace { const uint32_t& num_ranks, const uint32_t& num_experts, const uint32_t& num_max_tokens_per_rank, - const uint32_t& num_topk, - const uint32_t& block_m): + const uint32_t& num_topk): base(base), num_ranks(num_ranks), num_experts(num_experts), num_max_tokens_per_rank(num_max_tokens_per_rank) { num_experts_per_rank = num_experts / num_ranks; num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank; - num_max_pool_tokens = get_num_max_pool_tokens( - num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank, block_m); - num_max_pool_blocks = num_max_pool_tokens / block_m; - DG_UNIFIED_ASSERT(num_max_tokens_per_rank % block_m == 0); + num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank); + num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM; } CUTLASS_HOST_DEVICE diff --git a/deep_gemm/include/deep_gemm/ptx/ld_st.cuh b/deep_gemm/include/deep_gemm/ptx/ld_st.cuh index 6a64b9c..c3e03be 100644 --- a/deep_gemm/include/deep_gemm/ptx/ld_st.cuh +++ b/deep_gemm/include/deep_gemm/ptx/ld_st.cuh @@ -164,7 +164,7 @@ CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) { } CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) { - asm volatile("st.L1::no_allocate.relaxed.sys.u64 [%0], %1;" :: "l"(ptr), "l"(value)); + asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value)); } /// Atomics @@ -186,7 +186,11 @@ CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& valu return ret; } -__forceinline__ __device__ void red_add(const uint32_t* ptr, const uint32_t& value) { +CUTLASS_DEVICE void red_add(const int* ptr, const int& value) { + asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) { asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); } diff --git a/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh index 42f955f..548bbbc 100644 --- a/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh @@ -6,22 +6,51 @@ namespace deep_gemm::sched { -template +template CUTLASS_GLOBAL __launch_bounds__(32, 1) void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, - const uint32_t* context_lens, uint32_t* schedule_metadata) { + const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) { DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); const uint32_t lane_idx = ptx::get_lane_idx(); // Wait for primary kernel completion cudaGridDependencySynchronize(); + __shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize]; + __shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize]; + __shared__ uint32_t varlen_num_atoms_shared; + uint32_t num_items; + + if constexpr (kIsVarlen) { + if (lane_idx == 0) { + uint32_t t = 0, atom_count = 0; + while (t < batch_size) { + varlen_atom_token_start[atom_count] = t; + const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]); + varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t]; + t += is_paired ? 2 : 1; + ++ atom_count; + } + varlen_num_atoms_shared = atom_count; + } + __syncwarp(); + num_items = varlen_num_atoms_shared; + } else { + num_items = batch_size; + } + + // Compute num_segs and prefix sum uint32_t num_segs[kAlignedBatchSize / 32]; #pragma unroll for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { const uint32_t q_idx = k * 32 + lane_idx; - const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); - const uint32_t context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0); + uint32_t context_len; + if constexpr (kIsVarlen) { + context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0); + } else { + const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); + context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0); + } num_segs[k] = math::ceil_div(context_len, SPLIT_KV); } @@ -40,44 +69,118 @@ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t ne sum = __shfl_sync(0xffffffff, x, 31); } - const uint32_t num_next_n_atoms = next_n / ((next_n % 2 == 0) ? 2 : 1); - const uint32_t total = sum * num_next_n_atoms; - const uint32_t q = total / kNumSMs, r = total % kNumSMs; - for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { - uint32_t seg_starts = sm_idx * q + min(sm_idx, r); - uint32_t q_idx = 0; - while (q_idx < batch_size and prefix_sum[q_idx] * num_next_n_atoms <= seg_starts) - ++ q_idx; - const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms); - const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]); - const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0; - const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0; - const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx; - __syncwarp(); + // SM work distribution + if constexpr (kIsVarlen) { + const uint32_t total = sum; + const uint32_t q = total / kNumSMs, r = total % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t lo = 0, hi = num_items; + while (lo < hi) { + const uint32_t mid = (lo + hi) / 2; + const bool pred = prefix_sum[mid] <= seg_starts; + lo = pred ? mid + 1 : lo; + hi = pred ? hi : mid; + } + const uint32_t atom_idx = lo; + const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]); + const uint32_t q_atom_idx = (atom_idx < num_items ? varlen_atom_token_start[atom_idx] : batch_size); + __syncwarp(); - schedule_metadata[sm_idx * 2] = q_atom_idx; - schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + schedule_metadata[sm_idx * 2] = q_atom_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } + } else { + const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1; + const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom); + const uint32_t total = sum * num_next_n_atoms; + const uint32_t q = total / kNumSMs, r = total % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t lo = 0, hi = batch_size; + while (lo < hi) { + const uint32_t mid = (lo + hi) / 2; + const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts; + lo = pred ? mid + 1 : lo; + hi = pred ? hi : mid; + } + const uint32_t q_idx = lo; + const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms); + const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]); + const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0; + const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0; + const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx; + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_atom_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } } } -template +struct IndicesStorage { + const uint32_t* indices; +}; + +template <> +struct IndicesStorage {}; + +template -struct PagedMQALogitsScheduler { +struct PagedMQALogitsScheduler : IndicesStorage { const uint32_t* context_lens; + uint32_t batch_size; uint32_t current_q_atom_idx, current_kv_idx; uint32_t end_q_atom_idx, end_kv_idx; uint32_t current_num_kv; - CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const { - const uint32_t q_idx = q_atom_idx / kNumNextNAtoms; - const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); - return math::ceil_div(context_lens[lens_idx], BLOCK_KV); + CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) { + if constexpr (kIsVarlen) { + return q_atom_idx; + } else { + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + if constexpr (kPadOddN) { + return q_atom_idx / kNumNextNAtoms * kNextN + q_atom_idx % kNumNextNAtoms * kNextNAtom; + } else { + return q_atom_idx * kNextNAtom; + } + } } - CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t* context_lens, const uint32_t* schedule_meta) { + CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) { + if constexpr (kIsVarlen) { + return q_atom_idx; + } else { + return q_atom_idx / kNumNextNAtoms; + } + } + + CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const { + if constexpr (kIsVarlen) { + const bool is_paired = (q_atom_idx + 1 < batch_size and + this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]); + const uint32_t ctx_len = is_paired ? context_lens[q_atom_idx + 1] : context_lens[q_atom_idx]; + return math::ceil_div(ctx_len, BLOCK_KV); + } else { + const uint32_t q_idx = q_atom_idx / kNumNextNAtoms; + const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); + return math::ceil_div(context_lens[lens_idx], BLOCK_KV); + } + } + + CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size, + const uint32_t* context_lens, + const uint32_t* schedule_meta, const uint32_t* indices) { this->context_lens = context_lens; + this->batch_size = batch_size; + if constexpr (kIsVarlen) { + this->indices = indices; + } const auto current_pack = reinterpret_cast(schedule_meta)[sm_idx]; const auto end_pack = reinterpret_cast(schedule_meta)[sm_idx + 1]; @@ -87,6 +190,28 @@ struct PagedMQALogitsScheduler { current_num_kv = get_num_kv(current_q_atom_idx); } + // Advance step in q_atom_idx space when moving to the next atom. + // Varlen: 1 or 2 depending on whether consecutive tokens share the same sequence. + // Non-varlen: always 1 (one atom unit). + CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const { + if constexpr (kIsVarlen) { + return (q_atom_idx + 1 < bound and this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]) ? 2 : 1; + } else { + return 1; + } + } + + // Whether num_kv should be refreshed after advancing to q_atom_idx. + // Varlen: always refresh (each atom may have a different context_len). + // Non-varlen: only at atom-group boundaries (atoms within a group share context_len). + CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const { + if constexpr (kIsVarlen) { + return true; + } else { + return q_atom_idx % kNumNextNAtoms == 0; + } + } + CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) { q_atom_idx = current_q_atom_idx; kv_idx = current_kv_idx; @@ -97,9 +222,9 @@ struct PagedMQALogitsScheduler { current_kv_idx += kNumBlocksPerSplit; if (current_kv_idx >= current_num_kv) { - ++ current_q_atom_idx; current_kv_idx = 0; - if (current_q_atom_idx % kNumNextNAtoms == 0 and exist_q_atom_idx(current_q_atom_idx)) { + current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx); + if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) { current_num_kv = get_num_kv(current_q_atom_idx); } } diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 2010bad..e624ecf 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -61,10 +61,8 @@ def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup, hidden: int, intermediate_hidden: int, use_fp8_dispatch: bool = True, activation: str = 'swiglu') -> SymmBuffer: - # Token count must be aligned to block m - num_ranks = group.size() - block_m = _C.get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk) - num_max_tokens_per_rank = align(num_max_tokens_per_rank, block_m) + # Token count must be aligned to block sizes + num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe()) return SymmBuffer( group, num_experts, @@ -111,6 +109,7 @@ def fp8_fp4_mega_moe(y: torch.Tensor, l1_weights: Tuple[torch.Tensor, torch.Tensor], l2_weights: Tuple[torch.Tensor, torch.Tensor], sym_buffer: SymmBuffer, + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, recipe: Tuple[int, int, int] = (1, 1, 32), activation: str = 'swiglu', activation_clamp: Optional[float] = None, @@ -118,6 +117,7 @@ def fp8_fp4_mega_moe(y: torch.Tensor, _C.fp8_fp4_mega_moe( y, l1_weights, l2_weights, + cumulative_local_expert_recv_stats, sym_buffer.buffer, sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), sym_buffer.num_max_tokens_per_rank, diff --git a/scripts/quick_plot_pm.py b/scripts/quick_plot_pm.py new file mode 100644 index 0000000..3aee8b8 --- /dev/null +++ b/scripts/quick_plot_pm.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +"""Plot a curated set of NCU PM metrics from an .ncu-rep report. + +Usage: + python scripts/quick_plot_pm.py [report.ncu-rep] + +By default the script saves a PNG next to the report. +With --interactive, it opens a Qt window instead. +""" + +import argparse +import csv +import io +import subprocess +from dataclasses import dataclass + +import matplotlib +import numpy as np + + +@dataclass(frozen=True) +class MetricSpec: + name: str + metric: str + kind: str + category: str + aliases: tuple[str, ...] = () + + +@dataclass(frozen=True) +class ResolvedMetricSpec: + name: str + metric: str + kind: str + category: str + + +@dataclass(frozen=True) +class MetricSeries: + name: str + metric: str + category: str + unit: str + values: tuple[float, ...] + + +CATEGORY_ORDER = [ + "Overview", + "SM", + "L1", + "L2", + "DRAM", + "Interconnect", +] + + +KIND_SUFFIXES = { + "pct_peak": [".avg.pct_of_peak_sustained_elapsed"], + "pct": [".pct", ".avg.pct_of_peak_sustained_elapsed"], + "avg": [".avg"], + "sum": [".sum"], + "avg_per_second": [".avg.per_second"], + "sum_per_second": [".sum.per_second"], + "avg_per_cycle_active": [".avg.per_cycle_active"], + "avg_per_cycle_elapsed": [".avg.per_cycle_elapsed"], + "sum_per_cycle_elapsed": [".sum.per_cycle_elapsed"], +} + + +# Curated from scripts/ncu-metrics.txt, with a few corrections against +# `ncu --query-metrics --chip gb100`: +# - Blocks launched uses `gr__ctas_launched_realtime` +# - SM active cycles uses `sm__cycles_active` +# - L2 throughput for GCC requests uses `lts__t_sector_throughput_srcunit_gcc` +# - C2C throughput uses `ctc__throughput` +# - NVLink RX metrics use the `NVLRX` domain +CURATED_METRICS = [ + MetricSpec("Blocks Launched", "FE_B.TriageCompute.gr__ctas_launched_realtime", "sum_per_cycle_elapsed", "Overview"), + MetricSpec("Average Blocks Active", "TPC.TriageCompute.tpc__ctas_active_realtime", "avg_per_cycle_elapsed", "Overview"), + MetricSpec("Total Blocks Active", "TPC.TriageCompute.tpc__ctas_active_realtime", "sum_per_cycle_elapsed", "Overview"), + MetricSpec("Average CGAs Active", "GPC_B.TriageCompute.gpc__cgas_active_realtime", "avg_per_cycle_elapsed", "Overview"), + MetricSpec("Total CGAs Active", "GPC_B.TriageCompute.gpc__cgas_active_realtime", "sum_per_cycle_elapsed", "Overview"), + MetricSpec("SM Active Cycles", "TPC.TriageCompute.sm__cycles_active", "avg", "SM"), + MetricSpec("Executed IPC Active", "TPC.TriageCompute.sm__inst_executed_realtime", "avg_per_cycle_active", "SM"), + MetricSpec("Executed IPC Elapsed", "TPC.TriageCompute.sm__inst_executed_realtime", "avg_per_cycle_elapsed", "SM"), + MetricSpec("SM Throughput", "TPC.TriageCompute.sm__inst_executed_realtime", "pct_peak", "SM"), + MetricSpec("SM ALU Pipe Throughput", "TPC.TriageCompute.sm__inst_executed_pipe_alu_realtime", "pct_peak", "SM"), + MetricSpec("SM FMA Pipe Throughput", "TPC.TriageCompute.sm__pipe_fma_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM FMA Heavy Pipe Throughput", "TPC.TriageCompute.sm__pipe_fmaheavy_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM FMA Light Pipe Throughput", "TPC.TriageCompute.sm__pipe_fmalite_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM Tensor Pipe Throughput", "TPC.TriageCompute.sm__pipe_tensor_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM TMEM Pipe Throughput", "SM_A.TriageCompute.sm__mem_tensor_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM Uniform Pipe Throughput", "SM_A.TriageCompute.sm__inst_executed_pipe_uniform_realtime", "pct_peak", "SM"), + MetricSpec("SM XU Pipe Throughput", "SM_A.TriageCompute.sm__inst_executed_pipe_xu_realtime", "pct_peak", "SM"), + MetricSpec("L1 Throughput", "SM_A.TriageCompute.l1tex__throughput", "pct_peak", "L1"), + MetricSpec("L1 Sectors", "SM_B.TriageCompute.l1tex__t_sectors", "sum", "L1"), + MetricSpec("L1 Hit Rate", "SM_B.TriageCompute.l1tex__t_sector_hit_rate", "pct", "L1"), + MetricSpec("L1 Lookup Hit", "SM_B.TriageCompute.l1tex__t_sectors_lookup_hit", "sum", "L1"), + MetricSpec("L1 Lookup Miss", "SM_B.TriageCompute.l1tex__t_sectors_lookup_miss", "sum", "L1"), + MetricSpec("L1 Wavefronts (Data)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts", "avg", "L1"), + MetricSpec("L1 Wavefronts (LGDS)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts_mem_lgds", "avg", "L1"), + MetricSpec("L1 Wavefronts (Shared)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts_mem_shared", "avg", "L1"), + MetricSpec("L2 Throughput", "LTS.TriageCompute.lts__throughput", "pct_peak", "L2"), + MetricSpec("L2 Throughput for L1 Requests", "LTS.TriageCompute.lts__t_sector_throughput_srcunit_tex", "pct_peak", "L2"), + MetricSpec("L2 Throughput for GCC Requests", "LTS.TriageCompute.lts__t_sector_throughput_srcunit_gcc", "pct_peak", "L2"), + MetricSpec("L2 Throughput to DRAM", "LTS.TriageCompute.lts__t_sector_throughput_srcnode_fbp", "pct_peak", "L2"), + MetricSpec("SysL2 Throughput to Peer Memory", "SYSLTS.TriageCompute.syslts__t_sector_throughput_aperture_peer", "pct_peak", "L2"), + MetricSpec("SysL2 Throughput to System Memory", "SYSLTS.TriageCompute.syslts__t_sector_throughput_aperture_sysmem", "pct_peak", "L2"), + MetricSpec("L2 Hit Rate", "LTS.TriageCompute.lts__average_t_sector_hit_rate_realtime", "pct", "L2"), + MetricSpec("L2 Hit Rate From L1", "LTS.TriageCompute.lts__average_t_sector_hit_rate_srcunit_tex_realtime", "pct", "L2"), + MetricSpec("DRAM Frequency", "FBSP.TriageCompute.dram__cycles_elapsed", "avg_per_second", "DRAM"), + MetricSpec("DRAM Throughput", "FBSP.TriageCompute.dram__throughput", "pct_peak", "DRAM"), + MetricSpec("DRAM Read Throughput", "FBSP.TriageCompute.dram__read_throughput", "pct_peak", "DRAM"), + MetricSpec("DRAM Write Throughput", "FBSP.TriageCompute.dram__write_throughput", "pct_peak", "DRAM"), + MetricSpec("C2C Throughput", "TriageCompute.ctc__throughput", "pct_peak", "Interconnect", aliases=("TriageCompute.ctx__throughput",)), + MetricSpec("NVLink Transmitted Throughput", "NVLTX.TriageCompute.nvltx__bytes", "pct_peak", "Interconnect"), + MetricSpec("NVLink Received Throughput", "NVLRX.TriageCompute.nvlrx__bytes", "pct_peak", "Interconnect"), + MetricSpec("NVLink Transmitted Bandwidth", "NVLTX.TriageCompute.nvltx__bytes", "sum_per_second", "Interconnect"), + MetricSpec("NVLink Received Bandwidth", "NVLRX.TriageCompute.nvlrx__bytes", "sum_per_second", "Interconnect"), + MetricSpec("PCIe Throughput", "PCI.TriageCompute.pcie__throughput", "pct_peak", "Interconnect"), + MetricSpec("PCIe Read Bandwidth", "PCI.TriageCompute.pcie__read_bytes", "sum_per_second", "Interconnect"), + MetricSpec("PCIe Write Bandwidth", "PCI.TriageCompute.pcie__write_bytes", "sum_per_second", "Interconnect"), +] + + +def _run_csv_command(command, timeout): + result = subprocess.run(command, capture_output=True, text=True, timeout=timeout) + if result.returncode != 0 and not result.stdout: + return None + reader = csv.reader(io.StringIO(result.stdout)) + return list(reader) + + +def _query_available_metrics(chip): + result = subprocess.run( + ["ncu", "--query-metrics", "--chip", chip], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + raise RuntimeError(result.stderr.strip() or f"failed to query metrics for chip {chip}") + + metrics = set() + for line in result.stdout.splitlines(): + parts = line.split() + if not parts: + continue + token = parts[0] + if "__" not in token: + continue + metrics.add(token) + return metrics + + +def _metric_candidates(metric): + candidates = [metric] + marker = ".TriageCompute." + if marker in metric: + candidates.append(metric.split(marker, 1)[1]) + return candidates + + +def resolve_metric_specs(chip): + available = _query_available_metrics(chip) + resolved = [] + missing = [] + + for spec in CURATED_METRICS: + candidates = [] + for metric in (spec.metric, *spec.aliases): + candidates.extend(_metric_candidates(metric)) + + actual_metric = next((metric for metric in candidates if metric in available), None) + if actual_metric is None: + missing.append(spec) + continue + + resolved.append(ResolvedMetricSpec(spec.name, actual_metric, spec.kind, spec.category)) + + return resolved, missing + + +def _parse_metric_values(raw): + if not raw or raw == "no data": + return () + + try: + if raw.startswith("(") and raw.endswith(")"): + rest = raw[1:-1] + return tuple(float(v.strip().replace(",", "")) for v in rest.split(";") if v.strip()) + if " (" in raw: + _agg, rest = raw.split(" (", 1) + rest = rest.rstrip(")") + return tuple(float(v.strip().replace(",", "")) for v in rest.split(";") if v.strip()) + return (float(raw.replace(",", "")),) + except ValueError: + return () + + +def _probe_metric_series(report, metric_name): + rows = _run_csv_command( + [ + "ncu", + "--import", + report, + "--page", + "raw", + "--csv", + "--metrics", + metric_name, + "--print-metric-instances", + "values", + ], + timeout=60, + ) + if not rows or len(rows) < 3 or len(rows[0]) <= 11: + return None + + header, units, row = rows[0], rows[1], rows[2] + unit = units[11] if len(units) > 11 else "" + raw = row[11] if len(row) > 11 else "" + values = _parse_metric_values(raw) + return header[11], unit, values + + +def collect_metric_series(report, resolved_specs): + collected = [] + skipped = [] + + for spec in resolved_specs: + series = None + for suffix in KIND_SUFFIXES[spec.kind]: + probe = _probe_metric_series(report, f"{spec.metric}{suffix}") + if probe is None: + continue + full_metric, unit, values = probe + if len(values) > 1: + series = MetricSeries(spec.name, full_metric, spec.category, unit, values) + break + + if series is None: + skipped.append(spec) + continue + + collected.append(series) + + return collected, skipped + + +def _format_value(value): + if value == 0: + return "0" + abs_value = abs(value) + if abs_value >= 1e12: + return f"{value / 1e12:.2f} T" + if abs_value >= 1e9: + return f"{value / 1e9:.2f} G" + if abs_value >= 1e6: + return f"{value / 1e6:.2f} M" + if abs_value >= 1e3: + return f"{value / 1e3:.2f} K" + if abs_value >= 1: + return f"{value:.1f}" + return f"{value:.2f}" + + +def _format_with_unit(value, unit): + if not unit: + return _format_value(value) + return f"{_format_value(value)} {unit}" + + +def plot_pm(report, metrics, save=False): + """Plot curated PM metrics as shared-x subplots in a light theme.""" + import matplotlib.pyplot as plt + from matplotlib.gridspec import GridSpec + + if not metrics: + print("No curated metrics had time-series data in the report.") + return + + bg_fig = "#ffffff" + bg_row = "#f6f8fb" + text_primary = "#1f2937" + text_secondary = "#6b7280" + text_header = "#111827" + grid_color = "#d7deea" + border = "#c7d0dd" + + wave_colors = { + "Overview": "#7c8aa5", + "SM": "#4f87c2", + "L1": "#2f9d8f", + "L2": "#dd8452", + "DRAM": "#c95d63", + "Interconnect": "#8c6bb1", + } + + category_rank = {category: index for index, category in enumerate(CATEGORY_ORDER)} + metrics = sorted(metrics, key=lambda item: (category_rank.get(item.category, 99), item.name)) + + row_h = 0.55 + label_w = 3.6 + plot_w = 14.0 + fig_w = label_w + plot_w + fig_h = row_h * len(metrics) + 0.6 + + fig = plt.figure(figsize=(fig_w, fig_h), facecolor=bg_fig) + gs = GridSpec( + len(metrics), + 1, + figure=fig, + left=label_w / fig_w, + right=0.97, + top=1 - 0.45 / fig_h, + bottom=0.35 / fig_h, + hspace=0.18, + ) + axes = [fig.add_subplot(gs[i, 0]) for i in range(len(metrics))] + + prev_category = None + for idx, metric in enumerate(metrics): + ax = axes[idx] + values = np.array(metric.values) + x = np.arange(len(values)) + wave_color = wave_colors.get(metric.category, "#5b9bd5") + + ax.set_facecolor(bg_row) + ax.fill_between(x, values, alpha=0.35, color=wave_color, linewidth=0) + ax.plot(x, values, linewidth=0.8, color=wave_color) + + ax.set_xlim(0, len(values) - 1) + if metric.unit == "%": + ax.set_ylim(0, 100) + else: + ymax = np.max(values) + ax.set_ylim(0, ymax * 1.15 if ymax > 0 else 1) + + ax.grid(True, axis="both", color=grid_color, linewidth=0.5, alpha=0.85) + ax.tick_params(axis="both", colors=text_secondary, labelsize=6, length=0) + + if idx < len(metrics) - 1: + ax.tick_params(axis="x", labelbottom=False) + else: + ax.set_xlabel("Sample Index", fontsize=8, color=text_secondary) + + ymin_v, ymax_v = ax.get_ylim() + ax.set_yticks([ymin_v, ymax_v]) + ax.set_yticklabels([_format_value(ymin_v), _format_value(ymax_v)], fontsize=6, color=text_secondary) + + peak = np.max(values) + ax.text( + 1.005, + 0.5, + _format_with_unit(peak, metric.unit), + transform=ax.transAxes, + fontsize=7, + color=text_secondary, + va="center", + ha="left", + family="monospace", + ) + + for spine in ax.spines.values(): + spine.set_color(border) + spine.set_linewidth(0.5) + + if metric.category != prev_category: + cat_y = ax.get_position().y1 + 0.008 + fig.text( + 0.005, + cat_y, + f" {metric.category}", + fontsize=8.5, + fontweight="bold", + color=text_header, + va="bottom", + family="sans-serif", + transform=fig.transFigure, + bbox=dict(boxstyle="square,pad=0.15", facecolor="#e9eef5", edgecolor="none"), + ) + prev_category = metric.category + + label_y = (ax.get_position().y0 + ax.get_position().y1) / 2 + fig.text( + label_w / fig_w - 0.012, + label_y, + metric.name, + fontsize=7.5, + color=text_primary, + va="center", + ha="right", + family="sans-serif", + transform=fig.transFigure, + ) + + fig.text( + 0.5, + 1 - 0.15 / fig_h, + f"PM Sampling - {report}", + fontsize=11, + fontweight="bold", + color=text_header, + ha="center", + va="top", + family="sans-serif", + transform=fig.transFigure, + ) + + if save: + out_path = report.replace(".ncu-rep", ".pm_sampling.png") + fig.savefig(out_path, dpi=150, facecolor=bg_fig, bbox_inches="tight", pad_inches=0.2) + print(f"Saved: {out_path}") + plt.close(fig) + else: + plt.show() + + +def main(): + parser = argparse.ArgumentParser(description="NCU PM Sampling plotter") + parser.add_argument("report", nargs="?", default="mega-moe-kk.3.ncu-rep", help="Path to .ncu-rep file") + parser.add_argument("--chip", default="gb100", help="Chip name used for `ncu --query-metrics`") + parser.add_argument("--interactive", action="store_true", help="Open an interactive Qt window instead of saving a PNG") + args = parser.parse_args() + + if args.interactive: + matplotlib.use("QtAgg") + else: + matplotlib.use("Agg") + + resolved_specs, missing_specs = resolve_metric_specs(args.chip) + if missing_specs: + print(f"Skipped {len(missing_specs)} curated metrics not available on {args.chip}.") + for spec in missing_specs: + print(f" missing: {spec.name} -> {spec.metric}") + + metric_series, skipped_specs = collect_metric_series(args.report, resolved_specs) + if skipped_specs: + print(f"Skipped {len(skipped_specs)} curated metrics with no time-series data in {args.report}.") + for spec in skipped_specs: + print(f" no series: {spec.name} -> {spec.metric}") + + plot_pm(args.report, metric_series, save=not args.interactive) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_ncu_mega_moe.sh b/scripts/run_ncu_mega_moe.sh new file mode 100755 index 0000000..4324575 --- /dev/null +++ b/scripts/run_ncu_mega_moe.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +set -e + +# parse num-processes, output_dir and separate python args +num_processes=8 +output_dir=work +python_args=() +for ((arg_idx = 1; arg_idx <= $#; ++arg_idx)); do + arg="${!arg_idx}" + case "$arg" in + --num-processes) + python_args+=("$arg") + if ((arg_idx < $#)); then + ((arg_idx++)) + num_processes="${!arg_idx}" + python_args+=("$num_processes") + fi + ;; + -h|--help) + echo "Usage: $0 [--num-processes N] [--output DIR] [python args...]" + exit 0 + ;; + --num-processes=*) + num_processes="${arg#*=}" + python_args+=("$arg") + ;; + -o|--output) + if ((arg_idx < $#)); then + ((arg_idx++)) + output_dir="${!arg_idx}" + fi + ;; + --output=*) + output_dir="${arg#*=}" + ;; + *) + python_args+=("$arg") + ;; + esac +done + +echo "Python Args: ${python_args[*]}" +echo "Num Processes: $num_processes" +echo "Output Dir: $output_dir" +mkdir -p $output_dir + +export DG_JIT_WITH_LINEINFO=1 # for source counters + +echo "Warm up JIT cache" +python tests/test_mega_moe.py --ncu-profile-only "${python_args[@]}" + +sleep 2 + +ncu_args=( + --config-file off + --force-overwrite + --kernel-name sm100_fp8_fp4_mega_moe_impl + --import-source yes + --replay-mode application + --section PmSampling + --section SourceCounters + --rule LocalMemoryUsage + --launch-skip 0 + --launch-count 1 + --lockstep-kernel-launch + --communicator tcp + --clock-control none + --pm-sampling-interval 1000 + --pm-sampling-max-passes 1 + --disable-pm-warp-sampling + --communicator-tcp-num-peers "$num_processes" + --kill yes + --app-replay-buffer memory +) + +echo "Run Job" + +for ((i = 0; i < num_processes; ++i)); do + ncu ${ncu_args[@]} -o "${output_dir%/}/mega-moe.$i" \ + python tests/test_mega_moe.py \ + --local-rank-idx=$i \ + --ncu-profile-only \ + "${python_args[@]}" & +done + +echo "Waiting" +wait +echo "Done" diff --git a/tests/test_attention.py b/tests/test_attention.py index 6df1fc4..479da5b 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -253,36 +253,61 @@ def test_paged_mqa_logits(): def enumerate_paged_mqa_logits(): arch_major = get_arch_major() - for is_fp4 in ((True, False) if arch_major == 10 else (False, )): - for logits_dtype in (torch.float, torch.bfloat16): - for block_kv in ((32, 64) if arch_major == 10 else (64, )): - for use_2d_context_lens, clean_logits in [(True, False)]: - for batch_size in (256, ): - for next_n in (1, 2, 4, 5, 6) if arch_major == 10 else (1, 2): - for num_heads, head_dim in [(64, 128)]: - for avg_kv in (8192, 32768): - yield is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, num_heads, head_dim, avg_kv + for is_varlen in ((True, False) if arch_major == 10 else (False, )): + for is_fp4 in ((True, False) if arch_major == 10 else (False, )): + for logits_dtype in (torch.float, torch.bfloat16): + for block_kv in ((32, 64) if arch_major == 10 else (64, )): + for use_2d_context_lens, clean_logits in [(True, False)]: + for batch_size in (256, ): + for next_n in ((1, ) if is_varlen else ((1, 2, 4, 5, 6) if arch_major == 10 else (1, 2))): + for max_tokens_per_batch in ((1, 4, 10) if is_varlen else (1, )): + for num_heads, head_dim in [(64, 128)]: + for avg_kv in (8192, 32768): + yield is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv print('Testing FP8/FP4 Paged MQA Logits:') max_model_len = 111 * 1024 num_total_blocks = max_model_len * 5 - for is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits(): + for is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits(): + # Varlen: flatten raw_batch_size sequences with variable tokens into (batch_size, 1, ...) + raw_batch_size, raw_next_n = batch_size, next_n + if is_varlen: + tokens_per_seq = torch.randint(1, max_tokens_per_batch + 1, (raw_batch_size,), device='cuda', dtype=torch.int) + indices = torch.arange(raw_batch_size, device='cuda', dtype=torch.int).repeat_interleave(tokens_per_seq) + batch_size, next_n = tokens_per_seq.sum().item(), 1 + else: + tokens_per_seq, indices = None, None + # Generate random inputs q = torch.randn((batch_size, next_n, num_heads, head_dim), device='cuda', dtype=torch.bfloat16) kv_cache = torch.randn((num_total_blocks, block_kv, 1, head_dim), device='cuda', dtype=torch.bfloat16) weights = torch.randn((batch_size * next_n, num_heads), device='cuda', dtype=torch.float) - context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size,), device='cuda', dtype=torch.int) + context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (raw_batch_size,), device='cuda', dtype=torch.int) - # Assign block tables - num_blocks_per_query = ceil_div(context_lens, block_kv) - block_table = torch.empty((batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int) + if is_varlen: + max_ctx_len_per_seq = context_lens + (tokens_per_seq - 1) + else: + max_ctx_len_per_seq = context_lens + + # Assign block tables (per-sequence, sized by the largest ctx_len within the sequence) + seq_sum_lens = context_lens.sum().item() + num_blocks_per_query = ceil_div(max_ctx_len_per_seq, block_kv) + block_table = torch.empty((raw_batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int) block_idx_pool = torch.randperm(num_total_blocks, device='cuda', dtype=torch.int) offset = 0 for i, num_blocks in enumerate(num_blocks_per_query.tolist()): block_table[i, :num_blocks] = block_idx_pool[offset : offset + num_blocks] offset += num_blocks + if is_varlen: + context_lens = context_lens.repeat_interleave(tokens_per_seq) + offsets_within_seq = torch.cat([ + torch.arange(n.item(), device='cuda', dtype=torch.int) + for n in tokens_per_seq + ]) + context_lens = context_lens + offsets_within_seq + block_table = block_table.repeat_interleave(tokens_per_seq, dim=0) # Calculate reference logits ref_logits = ref_paged_mqa_logits(q, kv_cache, weights, context_lens, block_table, max_model_len, use_2d_context_lens) @@ -304,9 +329,14 @@ def test_paged_mqa_logits(): # Prepare masks and context lengths with NextN positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1) if use_2d_context_lens: - context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int() - # Ensure last token matches actual length - context_lens_nextn[:, -1] = context_lens + if is_varlen: + # Varlen: context_lens is already per-token (shape [total_tokens]); + # just reshape to (total_tokens, 1) so each token keeps its own ctx_len. + context_lens_nextn = context_lens.view(-1, 1) + else: + context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int() + # Ensure last token matches actual length + context_lens_nextn[:, -1] = context_lens ref_neginf_mask = ~(positions < context_lens_nextn.view(-1, 1)) else: context_lens_nextn = context_lens @@ -318,8 +348,9 @@ def test_paged_mqa_logits(): kernel_kwargs = dict( q=q_in, kv_cache=kv_in, weights=weights, context_lens=context_lens_nextn, block_table=block_table, - schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms()), - max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype + schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms(), indices=indices), + max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype, + indices=indices, ) logits = deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs) @@ -342,11 +373,15 @@ def test_paged_mqa_logits(): sum_lens = context_lens.sum().item() tflops_calc = 2 * sum_lens * next_n * num_heads * head_dim / 1e12 kv_bytes_per_token = head_dim / (2 if is_fp4 else 1) + 4 - total_bytes = count_bytes(q, weights) + sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize) + # KV is read once per sequence; for varlen sum_lens overcounts (per-token), so use seq_sum_lens + kv_sum_lens = seq_sum_lens if is_varlen else sum_lens + total_bytes = count_bytes(q, weights) + kv_sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize) t, clean_t = bench_kineto(lambda: deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs), ('paged_mqa_logits', 'clean_logits')) - print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={batch_size:3}, NextN={next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: ' + print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={raw_batch_size:3}, NextN={raw_next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: ' f'{tflops_calc / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, {total_bytes / t / 1e9:4.0f} GB/s', end='') + if is_varlen: + print(f' | Varlen, MaxTPB={max_tokens_per_batch}, NumTokens={batch_size}', end='') print(f' | clean: {clean_t*1e6:3.0f} us' if clean_logits else '') print() diff --git a/tests/test_mega_moe.py b/tests/test_mega_moe.py index ab9fd4b..e74b65e 100644 --- a/tests/test_mega_moe.py +++ b/tests/test_mega_moe.py @@ -9,24 +9,28 @@ from typing import Tuple import deep_gemm from deep_gemm.utils import per_token_cast_to_fp4, per_token_cast_to_fp8 from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather -from deep_gemm.testing import bench, bench_kineto, calc_diff +from deep_gemm.testing import bench_kineto -# Load legacy implements from third-party -# noinspection PyBroadException -try: - import deep_ep - import importlib.util - from tilelang.profiler.bench import do_bench - spec = importlib.util.spec_from_file_location( - 'tilelang_ops', - os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py')) - tilelang_ops = importlib.util.module_from_spec(spec) - sys.modules['tilelang_ops'] = tilelang_ops - spec.loader.exec_module(tilelang_ops) - is_legacy_loaded = True -except Exception as ex: - print(f'Failed to load legacy code: {ex}, skip baseline benchmarking') - is_legacy_loaded = False + +def import_baseline(): + # Load legacy implements from third-party + deep_ep, tilelang_ops, do_bench, is_legacy_loaded = None, None, None, False + # noinspection PyBroadException + try: + import deep_ep + import importlib.util + from tilelang.profiler.bench import do_bench + spec = importlib.util.spec_from_file_location( + 'tilelang_ops', + os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py')) + tilelang_ops = importlib.util.module_from_spec(spec) + sys.modules['tilelang_ops'] = tilelang_ops + spec.loader.exec_module(tilelang_ops) + is_legacy_loaded = True + except Exception as ex: + dist_print(f'Failed to load legacy code: {ex}, skip baseline benchmarking', once_in_node=True) + dist_print(once_in_node=True) + return deep_ep, tilelang_ops, do_bench, is_legacy_loaded # TODO: skip the test for SM90 @@ -51,29 +55,13 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden ) - dist_print('Config:', once_in_node=True) - dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True) - dist_print(f' > Hidden: {hidden}', once_in_node=True) - dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True) - dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True) - dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True) - dist_print(once_in_node=True) - - # Non-overlapped baseline: EP dispatch + GEMM + EP combine - alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() - deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) - ep_buffer = deep_ep.ElasticBuffer( - group, - num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden, - num_topk=num_topk, use_fp8_dispatch=True, - explicitly_destroy=True, - allow_multiple_reduction=False, - gpu_timeout_secs=10, cpu_timeout_secs=30 - ) if is_legacy_loaded else None # Create inputs + # noinspection PyGlobalUndefined def create_inputs(): global x, topk_idx, topk_weights, l1_weights, l2_weights, transformed_l1_weights, transformed_l2_weights + global cumulative_local_expert_recv_stats_fused + global cumulative_local_expert_recv_stats_baseline x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') l1_weights = torch.randn( (num_experts_per_rank, intermediate_hidden * 2, hidden), dtype=torch.bfloat16, device='cuda') @@ -81,6 +69,9 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): (num_experts_per_rank, hidden, intermediate_hidden), dtype=torch.bfloat16, device='cuda') scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda') topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + cumulative_local_expert_recv_stats_fused = torch.randint( + 0, 100, (num_experts_per_rank, ), dtype=torch.int, device='cuda') + cumulative_local_expert_recv_stats_baseline = cumulative_local_expert_recv_stats_fused.clone() if args.masked_ratio > 0: rand_mask = torch.rand_like(topk_idx, dtype=torch.float) topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) @@ -109,12 +100,67 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): l2_weights = cast_grouped_weights_to_fp4(l2_weights) transformed_l1_weights, transformed_l2_weights = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights) + # Run fused mega MoE + # NOTES: copy x into buffer before each call because debug mode zeros the entire buffer + def run_fused(): + buffer.x[:num_tokens].copy_(x[0]) + buffer.x_sf[:num_tokens].copy_(x[1]) + buffer.topk_idx[:num_tokens].copy_(topk_idx) + buffer.topk_weights[:num_tokens].copy_(topk_weights) + + y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + # noinspection PyTypeChecker + deep_gemm.fp8_fp4_mega_moe( + y, + transformed_l1_weights, transformed_l2_weights, + buffer, + cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_fused, + activation_clamp=args.activation_clamp, + fast_math=bool(args.fast_math) + ) + return y, cumulative_local_expert_recv_stats_fused + + dist_print('Config:', once_in_node=True) + dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True) + dist_print(f' > Hidden: {hidden}', once_in_node=True) + dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True) + dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True) + dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True) + dist_print(once_in_node=True) + + # Only do NCU profiling + if args.ncu_profile_only: + create_inputs() + dist_print(f'Run fused kernel:', once_in_node=True) + run_fused() + dist_print(f' > Done, exiting', once_in_node=True) + + # Destroy and exit + dist.barrier() + buffer.destroy() + dist.destroy_process_group() + return + + # Non-overlapped baseline: EP dispatch + GEMM + EP combine + deep_ep, tilelang_ops, tilelang_bench, is_legacy_loaded = import_baseline() + alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() + deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) + ep_buffer = deep_ep.ElasticBuffer( + group, + num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden, + num_topk=num_topk, use_fp8_dispatch=True, + explicitly_destroy=True, + allow_multiple_reduction=False, + gpu_timeout_secs=10, cpu_timeout_secs=30 + ) if is_legacy_loaded else None + def run_baseline(): recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch( x, topk_idx=topk_idx, topk_weights=topk_weights, + cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_baseline, num_experts=num_experts, expert_alignment=alignment, do_cpu_sync=False, do_handle_copy=False, - do_expand=True, use_tma_aligned_col_major_sf=True + do_expand=True, use_tma_aligned_col_major_sf=True, ) n = recv_x[0].size(0) l1_y = torch.empty((n, intermediate_hidden * 2), dtype=torch.bfloat16, device='cuda') @@ -138,26 +184,7 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous( l1_y, l2_weights, l2_y, handle.psum_num_recv_tokens_per_expert, use_psum_layout=True, recipe=(1, 1, 32)) - return ep_buffer.combine(l2_y, handle=handle)[0] - - # Run fused mega MoE - # NOTES: copy x into buffer before each call because debug mode zeros the entire buffer - def run_fused(): - buffer.x[:num_tokens].copy_(x[0]) - buffer.x_sf[:num_tokens].copy_(x[1]) - buffer.topk_idx[:num_tokens].copy_(topk_idx) - buffer.topk_weights[:num_tokens].copy_(topk_weights) - - y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - # noinspection PyTypeChecker - deep_gemm.fp8_fp4_mega_moe( - y, - transformed_l1_weights, transformed_l2_weights, - buffer, - activation_clamp=args.activation_clamp, - fast_math=bool(args.fast_math) - ) - return y + return ep_buffer.combine(l2_y, handle=handle)[0], cumulative_local_expert_recv_stats_baseline # Check correctness (must be bitwise identical) num_correctness_tests = 1 if args.num_correctness_tests is None else args.num_correctness_tests @@ -166,34 +193,36 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): dist_print('Running correctness tests:', once_in_node=True) for i in range(num_correctness_tests): create_inputs() - assert torch.equal(run_fused(), run_baseline()) + for fused_result, baseline_result in zip(run_fused(), run_baseline()): + assert torch.equal(fused_result, baseline_result) if (i + 1) % 100 == 0 or i == num_correctness_tests - 1: - dist_print(f' > Correctness test #{i + 1}/{args.num_correctness_tests} passed', once_in_node=True) + dist_print(f' > Correctness test #{i + 1}/{num_correctness_tests} passed', once_in_node=True) dist_print(once_in_node=True) else: create_inputs() # Count local received tokens gathered_topk_idx = uneven_all_gather(topk_idx, group=group) - num_recv_tokens = (rank_idx * num_experts_per_rank <= gathered_topk_idx) & \ - (gathered_topk_idx < (rank_idx + 1) * num_experts_per_rank) - num_recv_tokens = num_recv_tokens.sum().item() + gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | \ + (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1 + num_recv_tokens = (gathered_topk_idx != -1).sum().item() # Benchmark t_fused = bench_kineto( run_fused, 'mega_moe', barrier=lambda: ep_buffer.barrier(use_comm_stream=False) if ep_buffer else dist.barrier(), trace_path=None if not args.dump_profile_traces else f'{args.dump_profile_traces}/mega_moe_rank{rank_idx}.json') - t_baseline = do_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0 + t_baseline = tilelang_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0 # TFLOPS: 3 matmuls (L1 left, L1 right, L2), each 2 * M * N * K safe_div = lambda a, b: float('nan') if b == 0 else a / b tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused) # HBM bytes: weights (FP4 packed = 0.5 bytes) + activations (FP8 = 1 byte) + output (BF16 = 2 bytes) + num_touched_experts = torch.unique(gathered_topk_idx.flatten()).numel() - 1 # NOTES minus 1 to exclude "-1" num_hbm_bytes = ( - num_experts_per_rank * intermediate_hidden * 2 * hidden // 2 + # L1 weights (FP4) - num_experts_per_rank * hidden * intermediate_hidden // 2 + # L2 weights (FP4) + num_touched_experts * intermediate_hidden * 2 * hidden // 2 + # L1 weights (FP4) + num_touched_experts * hidden * intermediate_hidden // 2 + # L2 weights (FP4) num_recv_tokens * hidden + # L1 acts read (FP8) num_recv_tokens * intermediate_hidden + # L1 output write (FP8) num_recv_tokens * intermediate_hidden + # L2 acts read (FP8) @@ -230,7 +259,9 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Test PyTorch symmetric memory') + # Resource settings + parser.add_argument('--ncu-profile-only', action='store_true', help='Only run profiling without correctness test') parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') # Model settings