Add various optimizations and Mega MoE benchmarks (#316)
* Merge with private repo * Add Mega MoE Benchmark * Minor fix * Update --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
58
README.md
58
README.md
@@ -9,15 +9,15 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
|
||||
## News
|
||||
|
||||
- 2026.04.16: Mega MoE, FP8xFP4 GEMM, FP4 Indexer, PDL, faster JIT compilation and more.
|
||||
- Performance comparison will be posted later.
|
||||
- Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details.
|
||||
- Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details.
|
||||
- For Mega MoE benchmarks, refer to [#316](https://github.com/deepseek-ai/DeepGEMM/pull/316).
|
||||
- 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer for DeepSeek v3.2.
|
||||
- Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details.
|
||||
- Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details.
|
||||
- 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module.
|
||||
- NVRTC and post-compilation SASS optimization are all disabled.
|
||||
- NVRTC will be supported later.
|
||||
- As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported.
|
||||
- Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details.
|
||||
- NVRTC and post-compilation SASS optimization are all disabled.
|
||||
- NVRTC will be supported later.
|
||||
- As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported.
|
||||
- Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details.
|
||||
- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
|
||||
- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases).
|
||||
- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.
|
||||
@@ -30,9 +30,9 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
|
||||
- Python 3.8 or higher
|
||||
- Compilers with C++20 support
|
||||
- CUDA Toolkit:
|
||||
- CUDA 12.3 or higher for SM90
|
||||
- **We highly recommend 12.9 or higher for the best performance**
|
||||
- CUDA 12.9 or higher for SM100
|
||||
- CUDA 12.3 or higher for SM90
|
||||
- **We highly recommend 12.9 or higher for the best performance**
|
||||
- CUDA 12.9 or higher for SM100
|
||||
- PyTorch 2.1 or higher
|
||||
- CUTLASS 4.0 or higher (could be cloned by Git submodule)
|
||||
- `{fmt}` library (could be cloned by Git submodule)
|
||||
@@ -159,30 +159,30 @@ The library provides some utility functions besides the above kernels:
|
||||
The library also provides some environment variables, which may be useful:
|
||||
|
||||
- General
|
||||
- `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default
|
||||
- `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
|
||||
- `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default
|
||||
- `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
|
||||
- JIT cache
|
||||
- `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default
|
||||
- `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default
|
||||
- Compiler selection
|
||||
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, may have lower performance for some cases), `0` by default
|
||||
- `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME`
|
||||
- `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default
|
||||
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, may have lower performance for some cases), `0` by default
|
||||
- `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME`
|
||||
- `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default
|
||||
- Compiler output
|
||||
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default
|
||||
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default
|
||||
- `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default
|
||||
- `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default
|
||||
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default
|
||||
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default
|
||||
- `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default
|
||||
- `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default
|
||||
- Debug and profiling
|
||||
- `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default
|
||||
- `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default
|
||||
- `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default
|
||||
- `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default
|
||||
- `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero symmetric buffer before each Mega MoE call for debugging, `0` by default
|
||||
- `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default
|
||||
- `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default
|
||||
- `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default
|
||||
- `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default
|
||||
- `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default
|
||||
- `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero symmetric buffer before each Mega MoE call for debugging, `0` by default
|
||||
- `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default
|
||||
- Build options
|
||||
- `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip CUDA extension build during installation, `0` by default
|
||||
- `DG_FORCE_BUILD`: `0` or `1`, force local build instead of downloading pre-built wheels, `0` by default
|
||||
- `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default
|
||||
- `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip CUDA extension build during installation, `0` by default
|
||||
- `DG_FORCE_BUILD`: `0` or `1`, force local build instead of downloading pre-built wheels, `0` by default
|
||||
- `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default
|
||||
|
||||
For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.
|
||||
|
||||
|
||||
@@ -190,12 +190,13 @@ static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::opt
|
||||
return logits;
|
||||
}
|
||||
|
||||
static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms) {
|
||||
static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms, const std::optional<torch::Tensor>& indices) {
|
||||
// NOTES: Only 2D context lens is supported for now
|
||||
DG_HOST_ASSERT(context_lens.dim() == 2);
|
||||
const bool is_context_lens_2d = true;
|
||||
const int batch_size = context_lens.size(0);
|
||||
const int next_n = context_lens.size(1);
|
||||
const bool is_varlen = indices.has_value();
|
||||
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(context_lens.is_contiguous());
|
||||
|
||||
@@ -204,9 +205,16 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_
|
||||
|
||||
// Dispatch implementation
|
||||
const auto arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 or arch_major == 10) {
|
||||
if (is_varlen) {
|
||||
const auto& indices_tensor = indices.value();
|
||||
DG_HOST_ASSERT(arch_major == 10 and next_n == 1 and (block_kv == 64 or block_kv == 32));
|
||||
DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
|
||||
DG_HOST_ASSERT(indices_tensor.is_contiguous());
|
||||
DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
|
||||
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d, true, indices_tensor.data_ptr<int>());
|
||||
} else if (arch_major == 9 or arch_major == 10) {
|
||||
DG_HOST_ASSERT(block_kv == 64 or (arch_major == 10 and block_kv == 32));
|
||||
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d);
|
||||
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d, false, nullptr);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
@@ -222,7 +230,8 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
|
||||
const torch::Tensor& schedule_meta,
|
||||
const int& max_context_len,
|
||||
const bool& clean_logits,
|
||||
const at::ScalarType& logits_dtype) {
|
||||
const at::ScalarType& logits_dtype,
|
||||
const std::optional<torch::Tensor>& indices) {
|
||||
const auto [q_fp, q_sf] = q;
|
||||
const bool is_fp4 = q_sf.has_value();
|
||||
|
||||
@@ -321,6 +330,17 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
|
||||
DG_HOST_ASSERT(block_table.stride(1) == 1);
|
||||
DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt);
|
||||
|
||||
// Check indices
|
||||
const bool is_varlen = indices.has_value();
|
||||
const auto arch_major = device_runtime->get_arch_major();
|
||||
const auto indices_tensor = indices.value_or(torch::Tensor());
|
||||
if (is_varlen) {
|
||||
DG_HOST_ASSERT(arch_major == 10 and next_n == 1);
|
||||
DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
|
||||
DG_HOST_ASSERT(indices_tensor.is_contiguous());
|
||||
DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
|
||||
}
|
||||
|
||||
// Check schedule metadata
|
||||
auto [_schedule_meta_size, _meta_info_size] = get_shape<2>(schedule_meta);
|
||||
DG_HOST_ASSERT(_schedule_meta_size == num_sms + 1 and _meta_info_size == 2);
|
||||
@@ -344,15 +364,14 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
|
||||
DG_HOST_ASSERT(logits_dtype == torch::kFloat32 or logits_dtype == torch::kBFloat16);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto arch_major = device_runtime->get_arch_major();
|
||||
if (is_fp4 and arch_major == 10) {
|
||||
sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, schedule_meta,
|
||||
sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
|
||||
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
|
||||
aligned_max_context_len, block_table_stride, num_sms, split_kv);
|
||||
is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
|
||||
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) {
|
||||
smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, schedule_meta,
|
||||
smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
|
||||
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
|
||||
aligned_max_context_len, block_table_stride, num_sms, split_kv);
|
||||
is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
@@ -386,10 +405,11 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
const torch::Tensor& block_table,
|
||||
const torch::Tensor& schedule_meta,
|
||||
const int& max_context_len,
|
||||
const bool& clean_logits) {
|
||||
const bool& clean_logits,
|
||||
const std::optional<torch::Tensor>& indices) {
|
||||
return fp8_fp4_paged_mqa_logits(std::make_tuple(q, std::nullopt), fused_kv_cache, weights,
|
||||
context_lens, block_table, schedule_meta,
|
||||
max_context_len, clean_logits, torch::kFloat);
|
||||
max_context_len, clean_logits, torch::kFloat, indices);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -407,13 +427,15 @@ static void register_apis(pybind11::module_& m) {
|
||||
py::arg("max_seqlen_k") = 0,
|
||||
py::arg("logits_dtype") = torch::kFloat32);
|
||||
m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata,
|
||||
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"));
|
||||
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"),
|
||||
py::arg("indices") = std::nullopt);
|
||||
m.def("fp8_fp4_paged_mqa_logits", &fp8_fp4_paged_mqa_logits,
|
||||
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
|
||||
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
|
||||
py::arg("max_context_len"),
|
||||
py::arg("clean_logits") = false,
|
||||
py::arg("logits_dtype") = torch::kFloat32);
|
||||
py::arg("logits_dtype") = torch::kFloat32,
|
||||
py::arg("indices") = std::nullopt);
|
||||
// Legacy API
|
||||
m.def("fp8_mqa_logits", &fp8_mqa_logits,
|
||||
py::arg("q"), py::arg("kv"), py::arg("weights"),
|
||||
@@ -423,7 +445,8 @@ static void register_apis(pybind11::module_& m) {
|
||||
m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits,
|
||||
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
|
||||
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
|
||||
py::arg("max_context_len"), py::arg("clean_logits") = false);
|
||||
py::arg("max_context_len"), py::arg("clean_logits") = false,
|
||||
py::arg("indices") = std::nullopt);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,10 @@
|
||||
|
||||
namespace deep_gemm::mega {
|
||||
|
||||
static int get_token_alignment_for_mega_moe() {
|
||||
return layout::kLCMCandidateBlockM;
|
||||
}
|
||||
|
||||
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
|
||||
get_symm_buffer_size_for_mega_moe(
|
||||
const int& num_ranks, const int& num_experts,
|
||||
@@ -20,8 +24,7 @@ get_symm_buffer_size_for_mega_moe(
|
||||
DG_HOST_ASSERT(num_experts % num_ranks == 0);
|
||||
|
||||
// Workspace bytes
|
||||
const auto block_m = get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
|
||||
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk, block_m);
|
||||
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
|
||||
|
||||
// Layouts
|
||||
const auto fp8_token_layout = layout::Data(hidden);
|
||||
@@ -49,14 +52,20 @@ get_symm_buffer_size_for_mega_moe(
|
||||
|
||||
// Buffer configs
|
||||
const auto num_max_pool_tokens = static_cast<int>(workspace.num_max_pool_tokens);
|
||||
const auto num_padded_sf_pool_tokens = layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m);
|
||||
int num_max_padded_sf_pool_tokens = 0;
|
||||
for (int block_m: layout::kCandidateBlockM) {
|
||||
num_max_padded_sf_pool_tokens = std::max(
|
||||
num_max_padded_sf_pool_tokens,
|
||||
layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)
|
||||
);
|
||||
}
|
||||
|
||||
// L1 input buffer
|
||||
const auto l1_token_buffer = layout::Buffer(
|
||||
fp8_token_layout, 1, num_max_pool_tokens,
|
||||
input_topk_weights_buffer.get_end_ptr());
|
||||
const auto l1_sf_buffer = layout::Buffer(
|
||||
fp8_sf_layout, 1, num_padded_sf_pool_tokens,
|
||||
fp8_sf_layout, 1, num_max_padded_sf_pool_tokens,
|
||||
l1_token_buffer.get_end_ptr());
|
||||
const auto l1_topk_weights_buffer = layout::Buffer(
|
||||
l1_topk_weights_layout, 1, num_max_pool_tokens,
|
||||
@@ -67,7 +76,7 @@ get_symm_buffer_size_for_mega_moe(
|
||||
fp8_intermediate_token_layout, 1, num_max_pool_tokens,
|
||||
l1_topk_weights_buffer.get_end_ptr());
|
||||
const auto l2_sf_buffer = layout::Buffer(
|
||||
fp8_intermediate_sf_layout, 1, num_padded_sf_pool_tokens,
|
||||
fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
|
||||
l2_token_buffer.get_end_ptr());
|
||||
|
||||
// Combine input buffer: BF16 tokens for cross-rank combine
|
||||
@@ -77,7 +86,7 @@ get_symm_buffer_size_for_mega_moe(
|
||||
|
||||
// Check SF buffer requirements
|
||||
DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
|
||||
DG_HOST_ASSERT(num_padded_sf_pool_tokens % 4 == 0);
|
||||
DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);
|
||||
|
||||
// Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
|
||||
// NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
|
||||
@@ -104,8 +113,8 @@ get_symm_buffer_size_for_mega_moe(
|
||||
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
|
||||
auto l1_acts_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
|
||||
{num_padded_sf_pool_tokens, hidden / 128},
|
||||
{1, num_padded_sf_pool_tokens},
|
||||
{num_max_padded_sf_pool_tokens, hidden / 128},
|
||||
{1, num_max_padded_sf_pool_tokens},
|
||||
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
||||
auto l2_acts = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
|
||||
@@ -113,8 +122,8 @@ get_symm_buffer_size_for_mega_moe(
|
||||
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
|
||||
auto l2_acts_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
|
||||
{num_padded_sf_pool_tokens, intermediate_hidden / 128},
|
||||
{1, num_padded_sf_pool_tokens},
|
||||
{num_max_padded_sf_pool_tokens, intermediate_hidden / 128},
|
||||
{1, num_max_padded_sf_pool_tokens},
|
||||
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
||||
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
|
||||
};
|
||||
@@ -123,8 +132,9 @@ get_symm_buffer_size_for_mega_moe(
|
||||
|
||||
static void fp8_fp4_mega_moe(
|
||||
const torch::Tensor& y,
|
||||
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_,
|
||||
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_,
|
||||
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple,
|
||||
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_tuple,
|
||||
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
|
||||
const torch::Tensor& sym_buffer,
|
||||
const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
|
||||
const int& num_max_tokens_per_rank,
|
||||
@@ -132,9 +142,10 @@ static void fp8_fp4_mega_moe(
|
||||
const std::tuple<int, int, int>& recipe,
|
||||
const std::string& activation,
|
||||
const std::optional<float>& activation_clamp_opt,
|
||||
const bool& fast_math) {
|
||||
const auto [l1_weights, l1_weights_sf] = l1_weights_;
|
||||
const auto [l2_weights, l2_weights_sf] = l2_weights_;
|
||||
const bool& fast_math
|
||||
) {
|
||||
const auto [l1_weights, l1_weights_sf] = l1_weights_tuple;
|
||||
const auto [l2_weights, l2_weights_sf] = l2_weights_tuple;
|
||||
|
||||
// Config checks
|
||||
const auto num_tokens = static_cast<int>(y.size(0));
|
||||
@@ -161,13 +172,20 @@ static void fp8_fp4_mega_moe(
|
||||
DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
|
||||
DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());
|
||||
|
||||
// Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment
|
||||
// Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment
|
||||
constexpr int kGranMN = 1, kGranK = 32;
|
||||
check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
|
||||
num_experts_per_rank, true, false, torch::kInt);
|
||||
check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
|
||||
num_experts_per_rank, true, false, torch::kInt);
|
||||
|
||||
// Check stats counter
|
||||
if (cumulative_local_expert_recv_stats.has_value()) {
|
||||
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank);
|
||||
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous());
|
||||
}
|
||||
|
||||
// Check buffer bytes
|
||||
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
|
||||
const auto num_experts_ = num_experts_per_rank * num_ranks;
|
||||
@@ -175,7 +193,7 @@ static void fp8_fp4_mega_moe(
|
||||
num_ranks, num_experts,
|
||||
num_max_tokens_per_rank, num_topk,
|
||||
hidden, intermediate_hidden,
|
||||
true, "swiglu");
|
||||
true, activation);
|
||||
DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes));
|
||||
DG_HOST_ASSERT(num_experts == num_experts_);
|
||||
|
||||
@@ -189,6 +207,7 @@ static void fp8_fp4_mega_moe(
|
||||
l2_acts, l2_acts_sf,
|
||||
l1_weights, l2_weights,
|
||||
l1_weights_sf, l2_weights_sf,
|
||||
cumulative_local_expert_recv_stats,
|
||||
sym_buffer_ptrs,
|
||||
rank_idx, num_max_tokens_per_rank,
|
||||
num_experts_per_rank,
|
||||
@@ -207,7 +226,7 @@ static void fp8_fp4_mega_moe(
|
||||
|
||||
static void register_apis(pybind11::module_& m) {
|
||||
#if DG_TENSORMAP_COMPATIBLE
|
||||
m.def("get_block_m_for_mega_moe", &get_block_m_for_mega_moe);
|
||||
m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe);
|
||||
m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe);
|
||||
m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe);
|
||||
#endif
|
||||
|
||||
@@ -55,38 +55,68 @@ struct MegaMoEConfig {
|
||||
}
|
||||
};
|
||||
|
||||
static int get_block_m_for_mega_moe(const int& num_ranks, const int& num_experts,
|
||||
const int& num_max_tokens_per_rank, const int& num_topk) {
|
||||
// TODO: compute based on configs
|
||||
return 192;
|
||||
static std::tuple<int, int, int, int> get_block_config_for_mega_moe(
|
||||
const int& num_ranks, const int& num_experts,
|
||||
const int& num_max_tokens_per_rank, const int& num_topk,
|
||||
const int& num_tokens) {
|
||||
const auto& [cluster_size, block_m, store_block_m, num_epilogue_warpgroups] = [&]() -> std::tuple<int, int, int, int> {
|
||||
float num_expected_tokens_per_expert = static_cast<float>(num_tokens) * num_ranks * num_topk / num_experts;
|
||||
if (num_expected_tokens_per_expert <= 8.5) {
|
||||
// Really small token-per-expert (e.g. RL long-tail rollout), use the smallest block_m
|
||||
return {2, 16, 8, 2};
|
||||
} else if (num_expected_tokens_per_expert <= 16.5) {
|
||||
// Small batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 128
|
||||
return {2, 32, 16, 2};
|
||||
} else if (num_expected_tokens_per_expert <= 32.5) {
|
||||
// Medium batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 256
|
||||
return {2, 64, 32, 1};
|
||||
} else if (num_expected_tokens_per_expert <= 64.5) {
|
||||
// Large batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 512
|
||||
return {2, 96, 16, 2};
|
||||
} else if (num_expected_tokens_per_expert <= 96.5) {
|
||||
// Medium batch size, Medium EP, decoding, e.g. 6/384 experts, EP16, bsz 256, or EP32, bsz128
|
||||
return {2, 128, 32, 2};
|
||||
} else {
|
||||
// Prefill, or large EP decoding
|
||||
return {2, 192, 32, 2};
|
||||
}
|
||||
}();
|
||||
|
||||
// Check whether our `block_m` lies in `kCandidateBlockM`
|
||||
DG_HOST_ASSERT(std::any_of(
|
||||
layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs,
|
||||
[=](const auto& candidate) { return candidate == block_m; })
|
||||
);
|
||||
|
||||
// Return configs
|
||||
return {cluster_size, block_m, store_block_m, num_epilogue_warpgroups * 128};
|
||||
}
|
||||
|
||||
static int get_num_experts_per_wave_for_mega_moe(
|
||||
const int& num_experts_per_rank, const int& num_tokens, const int& num_topk,
|
||||
const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) {
|
||||
|
||||
float expected_tokens_per_expert = static_cast<float>(num_tokens) * num_topk / num_experts_per_rank;
|
||||
if (expected_tokens_per_expert < 1) {
|
||||
// Most experts don't have tokens, calculate all experts at once
|
||||
return num_experts_per_rank;
|
||||
}
|
||||
|
||||
// Reduce per-expert block count by this factor since uneven routing leaves some experts with fewer tokens
|
||||
constexpr int kImbalanceFactor = 2;
|
||||
|
||||
// TODO: support num_experts_per_rank > 32
|
||||
// Find the largest divisor of num_experts_per_rank that fits in 32 as the upper bound
|
||||
int max_num_experts_per_wave = std::min(32, num_experts_per_rank);
|
||||
while (max_num_experts_per_wave > 1 and num_experts_per_rank % max_num_experts_per_wave != 0)
|
||||
-- max_num_experts_per_wave;
|
||||
|
||||
// Count L1 blocks per expert assuming tokens are evenly spread across experts
|
||||
const int expected_tokens_per_expert =
|
||||
num_tokens * num_topk / num_experts_per_rank + 1;
|
||||
const int num_m_blocks = ceil_div(expected_tokens_per_expert, block_m);
|
||||
const int num_n_blocks = intermediate_hidden / block_n;
|
||||
const int num_m_blocks = ceil_div(static_cast<int>(std::ceil(expected_tokens_per_expert)), block_m);
|
||||
const int num_n_blocks = (2 * intermediate_hidden) / block_n;
|
||||
const int num_l1_blocks_per_expert = num_m_blocks * num_n_blocks;
|
||||
|
||||
// Pick the smallest value whose total blocks (after imbalance reduction) can keep all SMs busy
|
||||
int num_experts_per_wave = num_l1_blocks_per_expert > 0
|
||||
? ceil_div(kImbalanceFactor * num_sms, num_l1_blocks_per_expert) : 1;
|
||||
num_experts_per_wave = std::min(num_experts_per_wave, max_num_experts_per_wave);
|
||||
num_experts_per_wave = std::min(num_experts_per_wave, num_experts_per_rank);
|
||||
|
||||
// Round up to the nearest divisor of num_experts_per_rank so every wave processes the same count
|
||||
while (num_experts_per_wave < max_num_experts_per_wave and num_experts_per_rank % num_experts_per_wave != 0)
|
||||
while (num_experts_per_wave < num_experts_per_rank and num_experts_per_rank % num_experts_per_wave != 0)
|
||||
++ num_experts_per_wave;
|
||||
|
||||
return num_experts_per_wave;
|
||||
@@ -148,18 +178,18 @@ static std::pair<int, int> get_pipeline_config_for_mega_moe(
|
||||
static MegaMoEConfig get_mega_moe_config(
|
||||
const int& num_ranks, const int& num_experts, const int& num_experts_per_rank,
|
||||
const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk,
|
||||
const int& hidden, const int& intermediate_hidden) {
|
||||
// Block tiling
|
||||
const int block_m = get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
|
||||
const int& hidden, const int& intermediate_hidden,
|
||||
const int& num_padded_sf_pool_tokens) {
|
||||
// Block config
|
||||
const auto [cluster_size, block_m, store_block_m, num_epilogue_threads] =
|
||||
get_block_config_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens);
|
||||
const int block_n = 128;
|
||||
const int block_k = 128;
|
||||
const int load_block_m = block_m / 2;
|
||||
const int load_block_n = block_n;
|
||||
const int store_block_m = 32;
|
||||
const auto [sf_block_m, sf_block_n] = SM100ArchSpec::get_sf_uttcp_aligned_block_sizes(block_m, block_n, MmaKind::MXFP8FP4);
|
||||
const int num_max_pool_tokens = layout::get_num_max_pool_tokens(
|
||||
num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank, block_m);
|
||||
const int num_padded_sf_pool_tokens = layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m);
|
||||
num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
|
||||
// NOTES: FP8 activations and FP4 weights (unpacked to 8-bit in smem) both use 128B swizzle
|
||||
const int swizzle_acts_mode = 128;
|
||||
const int swizzle_weights_mode = 128;
|
||||
@@ -173,7 +203,6 @@ static MegaMoEConfig get_mega_moe_config(
|
||||
// Thread layout
|
||||
const int num_dispatch_threads = 128;
|
||||
const int num_non_epilogue_threads = 128;
|
||||
const int num_epilogue_threads = 256;
|
||||
|
||||
// Pipeline
|
||||
const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe(
|
||||
|
||||
@@ -29,6 +29,7 @@ public:
|
||||
|
||||
// Runtime arguments
|
||||
void* y;
|
||||
int* cumulative_local_expert_recv_stats;
|
||||
int num_tokens;
|
||||
layout::SymBuffer<> sym_buffer_ptrs;
|
||||
|
||||
@@ -91,6 +92,7 @@ static void __instantiate_kernel() {{
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.y,
|
||||
args.cumulative_local_expert_recv_stats,
|
||||
args.num_tokens,
|
||||
args.sym_buffer_ptrs,
|
||||
args.tensor_map_l1_acts,
|
||||
@@ -112,6 +114,7 @@ static void sm100_fp8_fp4_mega_moe(
|
||||
const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf,
|
||||
const torch::Tensor& l1_weights, const torch::Tensor& l2_weights,
|
||||
const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf,
|
||||
const std::optional<torch::Tensor> cumulative_local_expert_recv_stats,
|
||||
const std::vector<int64_t>& sym_buffer_ptrs,
|
||||
const int& rank_idx, const int& num_max_tokens_per_rank,
|
||||
const int& num_experts_per_rank,
|
||||
@@ -122,11 +125,12 @@ static void sm100_fp8_fp4_mega_moe(
|
||||
) {
|
||||
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
|
||||
const auto num_experts = num_experts_per_rank * num_ranks;
|
||||
const auto num_padded_sf_pool_tokens = static_cast<int>(l1_acts_sf.size(0));
|
||||
|
||||
// Heuristics
|
||||
const auto config = get_mega_moe_config(
|
||||
num_ranks, num_experts, num_experts_per_rank,
|
||||
num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden);
|
||||
num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden, num_padded_sf_pool_tokens);
|
||||
|
||||
// Make tensormap
|
||||
constexpr int kGranK = 32;
|
||||
@@ -175,6 +179,11 @@ static void sm100_fp8_fp4_mega_moe(
|
||||
config.block_n, kGranK,
|
||||
num_experts_per_rank, 0);
|
||||
|
||||
// Stats can be optional
|
||||
int* cumulative_local_expert_recv_stats_ptr = nullptr;
|
||||
if (cumulative_local_expert_recv_stats.has_value())
|
||||
cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr<int>();
|
||||
|
||||
// Launch
|
||||
const auto num_sms = device_runtime->get_num_sms();
|
||||
const SM100FP8FP4MegaMoERuntime::Args args = {
|
||||
@@ -186,6 +195,7 @@ static void sm100_fp8_fp4_mega_moe(
|
||||
.fast_math = fast_math,
|
||||
.config = config,
|
||||
.y = y.data_ptr(),
|
||||
.cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr,
|
||||
.num_tokens = num_tokens,
|
||||
.sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx),
|
||||
.tensor_map_l1_acts = tensor_map_l1_acts,
|
||||
|
||||
@@ -14,11 +14,13 @@ public:
|
||||
int aligned_batch_size;
|
||||
int split_kv;
|
||||
int num_sms;
|
||||
|
||||
bool is_varlen;
|
||||
|
||||
int batch_size;
|
||||
int next_n;
|
||||
bool is_context_lens_2d;
|
||||
int* context_lens;
|
||||
int* indices;
|
||||
int* schedule_metadata;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
@@ -32,10 +34,10 @@ using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sched::smxx_paged_mqa_logits_metadata<
|
||||
{}, {}, {}
|
||||
{}, {}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)", args.aligned_batch_size, args.split_kv, args.num_sms);
|
||||
)", args.aligned_batch_size, args.split_kv, args.num_sms, args.is_varlen ? "true" : "false");
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -44,6 +46,7 @@ static void __instantiate_kernel() {{
|
||||
args.next_n,
|
||||
args.is_context_lens_2d,
|
||||
args.context_lens,
|
||||
args.indices,
|
||||
args.schedule_metadata
|
||||
));
|
||||
}
|
||||
@@ -53,14 +56,15 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
|
||||
const torch::Tensor& schedule_metadata,
|
||||
const int& batch_size, const int& next_n,
|
||||
const int& block_kv, const int& num_sms,
|
||||
const bool& is_context_lens_2d) {
|
||||
const bool& is_context_lens_2d,
|
||||
const bool& is_varlen, const int* indices_ptr) {
|
||||
constexpr int split_kv = 256;
|
||||
constexpr int num_threads = 32;
|
||||
const int aligned_batch_size = align(batch_size, 32);
|
||||
DG_HOST_ASSERT(split_kv % block_kv == 0);
|
||||
|
||||
// Calculate shared memory size
|
||||
const int smem_size = aligned_batch_size * static_cast<int>(sizeof(int));
|
||||
// Shared memory: prefix_sum[kAlignedBatchSize] + varlen_atom_token_start/context_len[kAlignedBatchSize] + varlen_num_atoms
|
||||
const int smem_size = (3 * aligned_batch_size + 1) * static_cast<int>(sizeof(int));
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
|
||||
|
||||
@@ -69,10 +73,12 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
|
||||
.aligned_batch_size = aligned_batch_size,
|
||||
.split_kv = split_kv,
|
||||
.num_sms = num_sms,
|
||||
.is_varlen = is_varlen,
|
||||
.batch_size = batch_size,
|
||||
.next_n = next_n,
|
||||
.is_context_lens_2d = is_context_lens_2d,
|
||||
.context_lens = context_lens.data_ptr<int>(),
|
||||
.indices = const_cast<int*>(indices_ptr),
|
||||
.schedule_metadata = schedule_metadata.data_ptr<int>(),
|
||||
.launch_args = LaunchArgs(1, num_threads, smem_size)
|
||||
};
|
||||
@@ -90,6 +96,7 @@ public:
|
||||
int head_dim;
|
||||
int block_kv;
|
||||
bool is_context_lens_2d;
|
||||
bool is_varlen;
|
||||
int block_table_stride;
|
||||
int logits_stride;
|
||||
|
||||
@@ -100,6 +107,7 @@ public:
|
||||
int* context_lens;
|
||||
void* logits;
|
||||
int* block_table;
|
||||
int* indices;
|
||||
int* schedule_meta;
|
||||
|
||||
CUtensorMap tensor_map_q;
|
||||
@@ -129,7 +137,7 @@ static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_paged_mqa_logits<
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
@@ -139,7 +147,7 @@ static void __instantiate_kernel() {{
|
||||
)", arch, arch,
|
||||
args.next_n, args.num_heads,
|
||||
args.head_dim, args.block_kv,
|
||||
args.is_context_lens_2d,
|
||||
args.is_context_lens_2d, args.is_varlen ? "true" : "false",
|
||||
args.num_q_stages, args.num_kv_stages,
|
||||
args.split_kv,
|
||||
args.num_specialized_threads, args.num_math_threads,
|
||||
@@ -151,7 +159,7 @@ static void __instantiate_kernel() {{
|
||||
args.batch_size,
|
||||
args.logits_stride, args.block_table_stride,
|
||||
args.context_lens, args.logits,
|
||||
args.block_table, args.schedule_meta,
|
||||
args.block_table, args.indices, args.schedule_meta,
|
||||
args.tensor_map_q, args.tensor_map_kv,
|
||||
args.tensor_map_kv_scales, args.tensor_map_weights
|
||||
));
|
||||
@@ -165,12 +173,14 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
const torch::Tensor& context_lens,
|
||||
const torch::Tensor& logits,
|
||||
const torch::Tensor& block_table,
|
||||
const torch::Tensor& indices,
|
||||
const torch::Tensor& schedule_meta,
|
||||
const at::ScalarType& logits_dtype,
|
||||
const int& batch_size, const int& next_n,
|
||||
const int& num_heads, const int& head_dim,
|
||||
const int& num_kv_blocks, const int& block_kv,
|
||||
const bool& is_context_lens_2d,
|
||||
const bool& is_varlen,
|
||||
const int& logits_stride,
|
||||
const int& block_table_stride,
|
||||
const int& num_sms,
|
||||
@@ -183,7 +193,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0);
|
||||
|
||||
// Construct TMAs
|
||||
const int next_n_atom = (next_n % 2 == 0) ? 2 : 1;
|
||||
const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1;
|
||||
const auto tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads,
|
||||
head_dim, next_n_atom * num_heads,
|
||||
static_cast<int>(q.stride(2)),
|
||||
@@ -245,6 +255,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
.head_dim = head_dim,
|
||||
.block_kv = block_kv,
|
||||
.is_context_lens_2d = is_context_lens_2d,
|
||||
.is_varlen = is_varlen,
|
||||
.block_table_stride = block_table_stride,
|
||||
.logits_stride = logits_stride,
|
||||
.num_q_stages = num_q_stages,
|
||||
@@ -253,6 +264,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
.context_lens = context_lens.data_ptr<int>(),
|
||||
.logits = logits.data_ptr(),
|
||||
.block_table = block_table.data_ptr<int>(),
|
||||
.indices = is_varlen ? indices.data_ptr<int>() : nullptr,
|
||||
.schedule_meta = schedule_meta.data_ptr<int>(),
|
||||
.tensor_map_q = tensor_map_q,
|
||||
.tensor_map_kv = tensor_map_kv,
|
||||
@@ -279,6 +291,7 @@ public:
|
||||
int head_dim;
|
||||
int block_kv;
|
||||
bool is_context_lens_2d;
|
||||
bool is_varlen;
|
||||
int block_table_stride;
|
||||
int logits_stride;
|
||||
|
||||
@@ -289,6 +302,7 @@ public:
|
||||
int* context_lens;
|
||||
void* logits;
|
||||
int* block_table;
|
||||
int* indices;
|
||||
int* schedule_meta;
|
||||
|
||||
CUtensorMap tensor_map_q;
|
||||
@@ -314,7 +328,7 @@ static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm100_fp4_paged_mqa_logits<
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
@@ -323,7 +337,7 @@ static void __instantiate_kernel() {{
|
||||
}};
|
||||
)", args.next_n, args.num_heads,
|
||||
args.head_dim, args.block_kv,
|
||||
args.is_context_lens_2d,
|
||||
args.is_context_lens_2d, args.is_varlen ? "true" : "false",
|
||||
args.num_q_stages, args.num_kv_stages,
|
||||
args.split_kv,
|
||||
args.num_specialized_threads, args.num_math_threads,
|
||||
@@ -335,7 +349,7 @@ static void __instantiate_kernel() {{
|
||||
args.batch_size,
|
||||
args.logits_stride, args.block_table_stride,
|
||||
args.context_lens, args.logits,
|
||||
args.block_table, args.schedule_meta,
|
||||
args.block_table, args.indices, args.schedule_meta,
|
||||
args.tensor_map_q, args.tensor_map_sf_q,
|
||||
args.tensor_map_kv, args.tensor_map_sf_kv,
|
||||
args.tensor_map_weights
|
||||
@@ -351,12 +365,14 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q,
|
||||
const torch::Tensor& context_lens,
|
||||
const torch::Tensor& logits,
|
||||
const torch::Tensor& block_table,
|
||||
const torch::Tensor& indices,
|
||||
const torch::Tensor& schedule_meta,
|
||||
const at::ScalarType& logits_dtype,
|
||||
const int& batch_size, const int& next_n,
|
||||
const int& num_heads, const int& head_dim,
|
||||
const int& num_kv_blocks, const int& block_kv,
|
||||
const bool& is_context_lens_2d,
|
||||
const bool& is_varlen,
|
||||
const int& logits_stride,
|
||||
const int& block_table_stride,
|
||||
const int& num_sms,
|
||||
@@ -366,8 +382,8 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q,
|
||||
DG_HOST_ASSERT(split_kv == 256 and logits_stride % split_kv == 0);
|
||||
|
||||
// TODO: tuning num_stages
|
||||
const int num_q_stages = 3, num_kv_stages = 6, num_tmem_stages = 3;
|
||||
const int next_n_atom = (next_n % 2 == 0) ? 2 : 1;
|
||||
const int num_q_stages = 3, num_kv_stages = 10, num_tmem_stages = 3;
|
||||
const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1;
|
||||
|
||||
// `head_dim` must be 128 for 64B swizzling
|
||||
DG_HOST_ASSERT(head_dim == 128);
|
||||
@@ -416,6 +432,7 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q,
|
||||
.head_dim = head_dim,
|
||||
.block_kv = block_kv,
|
||||
.is_context_lens_2d = is_context_lens_2d,
|
||||
.is_varlen = is_varlen,
|
||||
.block_table_stride = block_table_stride,
|
||||
.logits_stride = logits_stride,
|
||||
.num_q_stages = num_q_stages,
|
||||
@@ -424,6 +441,7 @@ static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q,
|
||||
.context_lens = context_lens.data_ptr<int>(),
|
||||
.logits = logits.data_ptr(),
|
||||
.block_table = block_table.data_ptr<int>(),
|
||||
.indices = is_varlen ? indices.data_ptr<int>() : nullptr,
|
||||
.schedule_meta = schedule_meta.data_ptr<int>(),
|
||||
.tensor_map_q = tensor_map_q,
|
||||
.tensor_map_sf_q = tensor_map_sf_q,
|
||||
|
||||
@@ -123,4 +123,4 @@ _C.init(
|
||||
_find_cuda_home() # CUDA home
|
||||
)
|
||||
|
||||
__version__ = '2.4.2'
|
||||
__version__ = '2.5.0'
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
|
||||
#include <deep_gemm/ptx/ld_st.cuh>
|
||||
#include <deep_gemm/layout/sym_buffer.cuh>
|
||||
#include <deep_gemm/layout/mega_moe.cuh>
|
||||
|
||||
namespace deep_gemm::comm {
|
||||
|
||||
CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() {
|
||||
// Perform cluster_sync with `barrier.cluster.arrive.relaxed`
|
||||
// This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee
|
||||
cute::cluster_arrive_relaxed();
|
||||
cute::cluster_wait();
|
||||
}
|
||||
|
||||
template <uint32_t kNumSMs, uint32_t kGridSyncIndex = 0, typename sync_scope_t>
|
||||
CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace,
|
||||
const uint32_t& sm_idx, const uint32_t& thread_idx,
|
||||
|
||||
@@ -401,7 +401,8 @@ void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
// Load accumulator from TMEM
|
||||
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
|
||||
tmem_load(cute::Int<kNumHeads>{}, tmem_addr, accum);
|
||||
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
|
||||
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
|
||||
|
||||
// Release TMEM empty
|
||||
if (i == BLOCK_Q - 1) {
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace deep_gemm {
|
||||
|
||||
template <uint32_t kNextN, uint32_t kNumHeads,
|
||||
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
||||
bool kIsContextLens2D,
|
||||
bool kIsContextLens2D, bool kIsVarlen,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t SPLIT_KV,
|
||||
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
||||
@@ -30,7 +30,8 @@ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
||||
void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
const uint32_t logits_stride, const uint32_t block_table_stride,
|
||||
const uint32_t* context_lens, logits_dtype_t* logits,
|
||||
const uint32_t* block_table, const uint32_t* schedule_meta,
|
||||
const uint32_t* block_table, const uint32_t* indices,
|
||||
const uint32_t* schedule_meta,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
||||
@@ -54,10 +55,10 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
|
||||
}
|
||||
|
||||
// Next-N atom configs
|
||||
static constexpr uint32_t kNextNAtom = (kNextN % 2 == 0) ? 2 : 1;
|
||||
static constexpr uint32_t kNumNextNAtoms = kNextN / kNextNAtom;
|
||||
static constexpr bool kSingleAtom = (kNumNextNAtoms == 1);
|
||||
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
|
||||
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
|
||||
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
|
||||
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
|
||||
|
||||
// UMMA configs
|
||||
static constexpr uint32_t kNumTmemStages = 3;
|
||||
@@ -157,7 +158,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
|
||||
// Scheduler
|
||||
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
|
||||
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
|
||||
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
|
||||
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
|
||||
|
||||
// Make Q, KV and TMEM pipeline
|
||||
@@ -182,7 +183,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
|
||||
if (cute::elect_one_sync()) {
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
|
||||
// Persistently schedule over blocks
|
||||
// Initialize outside valid range to indicate no previous task
|
||||
@@ -196,11 +197,12 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
||||
|
||||
// Issue TMA Q
|
||||
const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx);
|
||||
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
|
||||
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
|
||||
smem_q[q_stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads);
|
||||
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_atom_idx * kNextNAtom);
|
||||
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_atom_idx * kNextNAtom);
|
||||
smem_q[q_stage_idx], 0, q_token_idx * kNumHeads);
|
||||
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx);
|
||||
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx);
|
||||
full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
|
||||
}
|
||||
last_q_atom_idx = q_atom_idx;
|
||||
@@ -210,7 +212,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
} else if (warp_idx == kSpecWarpStart + 1) {
|
||||
// TMA warp for loading KV cache
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
|
||||
// Persistently schedule over blocks
|
||||
uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage;
|
||||
@@ -225,10 +227,11 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
// Coalesced load of block table
|
||||
if (kv_block_idx_ptr == 32) {
|
||||
kv_block_idx_ptr = 0;
|
||||
const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast<uint64_t>(block_table_stride);
|
||||
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
|
||||
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
|
||||
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Broadcast KV block indices
|
||||
int kv_block_idx[kNumBlocksPerSplit];
|
||||
@@ -240,7 +243,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
|
||||
// Wait KV consumer release
|
||||
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
|
||||
|
||||
|
||||
// Issue TMA KV
|
||||
if (cute::elect_one_sync()) {
|
||||
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
||||
@@ -260,7 +263,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
} else if (warp_idx == kSpecWarpStart + 2) {
|
||||
// UMMA warp
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
|
||||
|
||||
// UTCCP transposer
|
||||
@@ -371,7 +374,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
} else if (warp_idx < kSpecWarpStart) {
|
||||
// Math warpgroups for reduce
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
|
||||
const auto math_warpgroup_idx = warpgroup_idx;
|
||||
const auto math_thread_idx = warp_idx * 32 + lane_idx;
|
||||
@@ -400,6 +403,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
// Persistently schedule over blocks
|
||||
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
|
||||
uint32_t q_atom_idx, kv_idx, _;
|
||||
bool is_paired_atom = false;
|
||||
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
|
||||
if (q_atom_idx != last_q_atom_idx) {
|
||||
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
|
||||
@@ -423,11 +427,16 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
weights[i][j + 3] = raw.w;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this atom pairs two tokens from the same sequence
|
||||
if constexpr (kIsVarlen) {
|
||||
is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2);
|
||||
}
|
||||
}
|
||||
last_q_atom_idx = q_atom_idx;
|
||||
|
||||
// Calculate KV offset in advance
|
||||
auto kv_offset = q_atom_idx * kNextNAtom * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
|
||||
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
|
||||
|
||||
// Advance pipeline by `kNumMathWarpGroups` steps
|
||||
// Wait UMMA arrival
|
||||
@@ -436,53 +445,58 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
ptx::tcgen05_after_thread_sync();
|
||||
|
||||
// Reduce over the head dim and store
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
|
||||
// Load accumulator from TMEM
|
||||
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
|
||||
tmem_load(cute::Int<kNumHeads>{}, tmem_addr, accum);
|
||||
const auto reduce_and_store = [&](auto num_iters_c) {
|
||||
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
|
||||
|
||||
// Only loop over valid iterations
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumIters; ++ i) {
|
||||
// Load accumulator from TMEM
|
||||
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
|
||||
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
|
||||
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
|
||||
|
||||
// Accumulate weighted ReLU in parallel
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
|
||||
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
||||
sum_0 = transform(j, sum_0);
|
||||
sum_1 = transform(j + 2, sum_1);
|
||||
}
|
||||
|
||||
auto sum = __fadd2_rn(sum_0, sum_1);
|
||||
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
|
||||
|
||||
// Store into the global memory
|
||||
logits[kv_offset + i * static_cast<uint64_t>(logits_stride)] = result;
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Release TMEM empty
|
||||
if (i == kNextNAtom - 1) {
|
||||
ptx::tcgen05_before_thread_sync();
|
||||
empty_tmem_barriers[tmem_stage_idx]->arrive();
|
||||
}
|
||||
ptx::tcgen05_before_thread_sync();
|
||||
empty_tmem_barriers[tmem_stage_idx]->arrive();
|
||||
};
|
||||
|
||||
// Accumulate weighted ReLU in parallel
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
|
||||
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
||||
sum_0 = transform(j, sum_0);
|
||||
sum_1 = transform(j + 2, sum_1);
|
||||
}
|
||||
|
||||
auto sum = __fadd2_rn(sum_0, sum_1);
|
||||
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
|
||||
|
||||
// Store into the global memory
|
||||
const auto dst_offset = kv_offset + i * static_cast<uint64_t>(logits_stride);
|
||||
if constexpr(sizeof(logits_dtype_t) == 2) {
|
||||
// Pack two adjacent bf16 lanes into uint32 for wider store
|
||||
uint16_t my_bits = *reinterpret_cast<const uint16_t*>(&result);
|
||||
uint16_t neighbor_bits = __shfl_down_sync(0xffffffff, my_bits, 1);
|
||||
uint32_t packed;
|
||||
asm volatile("mov.b32 %0, {%1, %2};" : "=r"(packed) : "h"(my_bits), "h"(neighbor_bits));
|
||||
if (lane_idx % 2 == 0)
|
||||
*reinterpret_cast<uint32_t*>(logits + dst_offset) = packed;
|
||||
} else {
|
||||
logits[dst_offset] = result;
|
||||
}
|
||||
// this sync warp prevent the next load tmem from reordering
|
||||
// nvcc may reorder it to overlap with the current tmem load, lead to large register usage
|
||||
__syncwarp();
|
||||
if constexpr (kIsVarlen) {
|
||||
if (is_paired_atom)
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
else
|
||||
reduce_and_store(cute::Int<1>{});
|
||||
} else if constexpr (kPadOddN) {
|
||||
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
|
||||
reduce_and_store(cute::Int<1>{});
|
||||
else
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
} else {
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ template <
|
||||
>
|
||||
CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void
|
||||
sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
int* cumulative_local_expert_recv_stats,
|
||||
const uint32_t num_tokens,
|
||||
const __grid_constant__ layout::SymBuffer<kNumRanks> sym_buffer,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts,
|
||||
@@ -91,7 +92,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
|
||||
// Workspaces
|
||||
const auto workspace = layout::Workspace(
|
||||
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk, BLOCK_M);
|
||||
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk);
|
||||
|
||||
// Token and buffer layouts
|
||||
constexpr auto fp8_token_layout = layout::Data(kHidden);
|
||||
@@ -170,7 +171,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
constexpr uint32_t UMMA_K = 32;
|
||||
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A
|
||||
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N;
|
||||
DG_STATIC_ASSERT(BLOCK_M % 32 == 0, "Invalid block M");
|
||||
DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M");
|
||||
DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N");
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
|
||||
|
||||
@@ -269,7 +270,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2);
|
||||
|
||||
// A cluster sync is essential for 2CTA tensor memory allocation
|
||||
cute::cluster_sync();
|
||||
comm::cluster_sync_with_relaxed_arrive();
|
||||
|
||||
// Initialization
|
||||
if (warp_idx == 0) {
|
||||
@@ -307,7 +308,9 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
// Allocate tensor memory
|
||||
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
cute::cluster_sync();
|
||||
// NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`,
|
||||
// and `barrier.cluster.wait.aligned` is by default `.acquire`
|
||||
comm::cluster_sync_with_relaxed_arrive();
|
||||
|
||||
// Task scheduler
|
||||
auto scheduler = sched::MegaMoEScheduler<
|
||||
@@ -599,7 +602,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Clean workspace for the next usage
|
||||
// Clean workspace for the next usage, and also do cumulative stats
|
||||
// NOTES: it is overlapped with combine reduction epilogue
|
||||
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
|
||||
|
||||
@@ -623,19 +626,27 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
// Wait read count ready
|
||||
ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
|
||||
|
||||
// Clean expert token count
|
||||
if (thread_idx == 0)
|
||||
// Clean expert token count, and add cumulative results
|
||||
DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps");
|
||||
if (warp_idx == 0) {
|
||||
*workspace.get_expert_recv_count_sum_ptr(i) = 0;
|
||||
} else if (warp_idx == 1) {
|
||||
if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr)
|
||||
ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast<int>(num_recv_tokens));
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Clean per-rank token count
|
||||
for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads)
|
||||
*workspace.get_expert_recv_count_ptr(j, i) = 0;
|
||||
__syncwarp();
|
||||
|
||||
// Clean L1 and L2 arrival stuffs
|
||||
for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) {
|
||||
*workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0;
|
||||
*workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -672,23 +683,22 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx);
|
||||
const auto expected = scheduler.template get_valid_m<false>();
|
||||
while (ptx::ld_acq(ptr) != expected);
|
||||
} else {
|
||||
// The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival
|
||||
// NOTES: Originally we wait blocks on-demand to overlap L1 calculation
|
||||
// with L2, but this optimization is negative when `num_experts_per_wave`
|
||||
// guarantees L1's completion when L2 starts. So we remove it.
|
||||
// In the future, if `num_experts_per_wave` is not large enough
|
||||
// due to small `num_experts_per_rank`, we may need to add it back or add a switch
|
||||
DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes");
|
||||
const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx);
|
||||
// NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts
|
||||
// to avoid undefined behavior when `num_k_blocks == 32`
|
||||
const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1;
|
||||
while (ptx::ld_acq_gpu(ptr) != expected);
|
||||
}
|
||||
|
||||
uint64_t cached_l2_arrival_mask = 0;
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
|
||||
// Wait current K block arrival
|
||||
if (block_phase == sched::BlockPhase::Linear2) {
|
||||
// The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2 L1 blocks' arrival
|
||||
DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes");
|
||||
const uint64_t needed = 3ull << (k_block_idx * 2);
|
||||
if ((cached_l2_arrival_mask & needed) != needed) {
|
||||
const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx);
|
||||
do {
|
||||
cached_l2_arrival_mask = ptx::ld_acq_gpu(ptr);
|
||||
} while ((cached_l2_arrival_mask & needed) != needed);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait consumer release
|
||||
empty_barriers[stage_idx]->wait(phase ^ 1);
|
||||
|
||||
@@ -953,8 +963,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
|
||||
// Load weights from global into register cache per 32 tokens
|
||||
DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size");
|
||||
DG_STATIC_ASSERT(WG_BLOCK_M % 32 == 0, "Invalid block size");
|
||||
if ((j * ATOM_M) % 32 == 0) {
|
||||
if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) {
|
||||
stored_cached_weight = *l1_topk_weights_buffer
|
||||
.get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx)
|
||||
.get_base_ptr<float>();
|
||||
@@ -1060,19 +1069,26 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
// Only one warp per pair writes (both hold the same SF after cross-warp reduce)
|
||||
// Each lane < 4 holds SF for 2 rows (sf.x and sf.y)
|
||||
if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {
|
||||
// TODO: I believe the expression can be optimized
|
||||
const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M
|
||||
+ epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2;
|
||||
const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2;
|
||||
const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4;
|
||||
const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t);
|
||||
const auto sf_base_ptr = l2_sf_buffer.get_base_ptr<uint8_t>();
|
||||
// NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4
|
||||
// NOTES: originally there was:
|
||||
// - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2
|
||||
// - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)`
|
||||
// We find out that
|
||||
// 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside
|
||||
// 2. `lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside
|
||||
// This reduce the number of computation instructions.
|
||||
const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
|
||||
__builtin_assume(token_base_idx < BLOCK_M);
|
||||
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
|
||||
+ transform_sf_token_idx(token_idx_in_expert);
|
||||
sf_base_ptr[k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx] =
|
||||
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
|
||||
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
|
||||
sf_base_ptr[sf_addr] =
|
||||
(*reinterpret_cast<const uint32_t*>(&sf.x) >> 23);
|
||||
sf_base_ptr[k_uint_idx * mn_stride + (sf_pool_token_idx + 4) * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx] =
|
||||
sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] =
|
||||
(*reinterpret_cast<const uint32_t*>(&sf.y) >> 23);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace deep_gemm {
|
||||
|
||||
template <uint32_t kNextN, uint32_t kNumHeads,
|
||||
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
||||
bool kIsContextLens2D,
|
||||
bool kIsContextLens2D, bool kIsVarlen,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t SPLIT_KV,
|
||||
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
||||
@@ -30,7 +30,8 @@ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
||||
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
const uint32_t logits_stride, const uint32_t block_table_stride,
|
||||
const uint32_t* context_lens, logits_dtype_t* logits,
|
||||
const uint32_t* block_table, const uint32_t* schedule_meta,
|
||||
const uint32_t* block_table, const uint32_t* indices,
|
||||
const uint32_t* schedule_meta,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
||||
@@ -53,10 +54,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
||||
}
|
||||
|
||||
// Next-N atom configs
|
||||
static constexpr uint32_t kNextNAtom = (kNextN % 2 == 0) ? 2 : 1;
|
||||
static constexpr uint32_t kNumNextNAtoms = kNextN / kNextNAtom;
|
||||
static constexpr bool kSingleAtom = (kNumNextNAtoms == 1);
|
||||
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
|
||||
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
|
||||
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
|
||||
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
|
||||
|
||||
// Shared memory configs
|
||||
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
||||
@@ -136,7 +137,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
|
||||
// Scheduler
|
||||
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
|
||||
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
|
||||
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
|
||||
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
|
||||
|
||||
// Q and KV pipeline
|
||||
@@ -157,13 +158,14 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
if (warp_idx == kSpecWarpStart) {
|
||||
// TMA warp for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
||||
|
||||
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_atom_idx) {
|
||||
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) {
|
||||
if (cute::elect_one_sync()) {
|
||||
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads);
|
||||
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_atom_idx * kNextNAtom);
|
||||
const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx);
|
||||
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads);
|
||||
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx);
|
||||
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
||||
}
|
||||
};
|
||||
@@ -182,7 +184,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
|
||||
while (fetched_next_task) {
|
||||
// Prefetch next Q when (q, atom) changes
|
||||
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + 1);
|
||||
const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size);
|
||||
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance);
|
||||
|
||||
if (q_atom_idx != next_q_atom_idx)
|
||||
kv_block_idx_ptr = 32;
|
||||
@@ -195,17 +198,18 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
// TODO(xuzhean): consider -1
|
||||
if (kv_block_idx_ptr == 32) {
|
||||
kv_block_idx_ptr = 0;
|
||||
const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast<uint64_t>(block_table_stride);
|
||||
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
|
||||
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
|
||||
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
|
||||
}
|
||||
__syncwarp();
|
||||
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
|
||||
|
||||
// Wait Q consumer release and issue TMA Q
|
||||
if (prefetch_q) {
|
||||
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
||||
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
||||
issue_tma_q(q_stage_idx, q_atom_idx + 1);
|
||||
issue_tma_q(q_stage_idx, q_atom_idx + next_advance);
|
||||
}
|
||||
|
||||
uint32_t kv_block_idx[kNumBlocksPerSplit];
|
||||
@@ -236,7 +240,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
}
|
||||
} else if (warp_idx == kSpecWarpStart + 1) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
||||
|
||||
// Require full allocation
|
||||
@@ -292,7 +296,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
} else if (warp_idx < kSpecWarpStart) {
|
||||
// Math warpgroups for reduce
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
||||
|
||||
// Offsets
|
||||
@@ -321,6 +325,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
|
||||
uint32_t q_stage_idx, q_phase;
|
||||
uint32_t umma_phase = 0;
|
||||
bool is_paired_atom = false;
|
||||
|
||||
while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
|
||||
// Q or atom changes
|
||||
@@ -340,6 +345,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
for (uint32_t j = 0; j < kNumHeads; ++ j)
|
||||
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
|
||||
}
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2);
|
||||
}
|
||||
}
|
||||
|
||||
// Get current task indices
|
||||
@@ -347,7 +356,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
kv_idx = next_kv_idx;
|
||||
|
||||
// Calculate KV offset in advance
|
||||
auto kv_offset = q_atom_idx * kNextNAtom * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
|
||||
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
|
||||
|
||||
// Wait TMA KV arrival
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
||||
@@ -367,40 +376,56 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
// Reduce over the head dim and store
|
||||
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
|
||||
// Load accumulator from TMEM
|
||||
const auto reduce_and_store = [&](auto num_iters_c) {
|
||||
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
|
||||
float accum[kNumHeads];
|
||||
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
|
||||
|
||||
// Release TMEM empty
|
||||
if (i == kNextNAtom - 1) {
|
||||
ptx::tcgen05_before_thread_sync();
|
||||
empty_umma_barriers[math_warpgroup_idx]->arrive();
|
||||
}
|
||||
|
||||
// Accumulate weighted ReLU in parallel
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
|
||||
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
||||
sum_0 = transform(j, sum_0);
|
||||
sum_1 = transform(j + 2, sum_1);
|
||||
for (uint32_t i = 0; i < kNumIters; ++ i) {
|
||||
// Load accumulator from TMEM
|
||||
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
|
||||
|
||||
// Accumulate weighted ReLU in parallel
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
|
||||
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
||||
sum_0 = transform(j, sum_0);
|
||||
sum_1 = transform(j + 2, sum_1);
|
||||
}
|
||||
|
||||
auto sum = __fadd2_rn(sum_0, sum_1);
|
||||
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
|
||||
|
||||
// Store into the global memory
|
||||
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
auto sum = __fadd2_rn(sum_0, sum_1);
|
||||
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
|
||||
// Release TMEM empty
|
||||
ptx::tcgen05_before_thread_sync();
|
||||
empty_umma_barriers[math_warpgroup_idx]->arrive();
|
||||
};
|
||||
|
||||
// Store into the global memory
|
||||
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
|
||||
__syncwarp();
|
||||
if constexpr (kIsVarlen) {
|
||||
if (is_paired_atom)
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
else
|
||||
reduce_and_store(cute::Int<1>{});
|
||||
} else if constexpr (kPadOddN) {
|
||||
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
|
||||
reduce_and_store(cute::Int<1>{});
|
||||
else
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
} else {
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace deep_gemm {
|
||||
|
||||
template <uint32_t kNextN, uint32_t kNumHeads,
|
||||
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
||||
bool kIsContextLens2D,
|
||||
bool kIsContextLens2D, bool kIsVarlen,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t SPLIT_KV,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
||||
@@ -30,11 +30,14 @@ CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
||||
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
const uint32_t logits_stride, const uint32_t block_table_stride,
|
||||
const uint32_t* context_lens, logits_dtype_t* logits,
|
||||
const uint32_t* block_table, const uint32_t* schedule_meta,
|
||||
const uint32_t* block_table, const uint32_t* indices,
|
||||
const uint32_t* schedule_meta,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
||||
DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename mma::sm90::FP8MMASelector<kNextN * kNumHeads>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
@@ -132,8 +135,8 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
cudaGridDependencySynchronize();
|
||||
|
||||
// Scheduler
|
||||
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups, 1>(
|
||||
blockIdx.x, context_lens, schedule_meta);
|
||||
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumMathWarpGroups, 1>(
|
||||
blockIdx.x, batch_size, context_lens, schedule_meta, indices);
|
||||
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
|
||||
|
||||
// Q and KV pipeline
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/numeric/math.hpp>
|
||||
|
||||
#include <deep_gemm/common/math.cuh>
|
||||
#include <deep_gemm/common/exception.cuh>
|
||||
|
||||
namespace deep_gemm::layout {
|
||||
|
||||
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding
|
||||
static constexpr int kNumCandidateBlockMs = 7;
|
||||
static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192};
|
||||
static constexpr int kMaxCandidateBlockM = 192;
|
||||
static constexpr int kMinCandidateBlockM = 8;
|
||||
static constexpr int kLCMCandidateBlockM = 384;
|
||||
|
||||
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
|
||||
T num_experts_per_rank, T block_m) {
|
||||
T num_experts_per_rank) {
|
||||
const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank;
|
||||
const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
|
||||
return math::constexpr_align(
|
||||
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (block_m - 1),
|
||||
block_m);
|
||||
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1),
|
||||
static_cast<T>(kLCMCandidateBlockM));
|
||||
}
|
||||
|
||||
// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M
|
||||
@@ -48,17 +56,14 @@ struct Workspace {
|
||||
const uint32_t& num_ranks,
|
||||
const uint32_t& num_experts,
|
||||
const uint32_t& num_max_tokens_per_rank,
|
||||
const uint32_t& num_topk,
|
||||
const uint32_t& block_m):
|
||||
const uint32_t& num_topk):
|
||||
base(base),
|
||||
num_ranks(num_ranks), num_experts(num_experts),
|
||||
num_max_tokens_per_rank(num_max_tokens_per_rank) {
|
||||
num_experts_per_rank = num_experts / num_ranks;
|
||||
num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
|
||||
num_max_pool_tokens = get_num_max_pool_tokens(
|
||||
num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank, block_m);
|
||||
num_max_pool_blocks = num_max_pool_tokens / block_m;
|
||||
DG_UNIFIED_ASSERT(num_max_tokens_per_rank % block_m == 0);
|
||||
num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
|
||||
num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@@ -164,7 +164,7 @@ CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) {
|
||||
asm volatile("st.L1::no_allocate.relaxed.sys.u64 [%0], %1;" :: "l"(ptr), "l"(value));
|
||||
asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value));
|
||||
}
|
||||
|
||||
/// Atomics
|
||||
@@ -186,7 +186,11 @@ CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& valu
|
||||
return ret;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void red_add(const uint32_t* ptr, const uint32_t& value) {
|
||||
CUTLASS_DEVICE void red_add(const int* ptr, const int& value) {
|
||||
asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) {
|
||||
asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value));
|
||||
}
|
||||
|
||||
|
||||
@@ -6,22 +6,51 @@
|
||||
|
||||
namespace deep_gemm::sched {
|
||||
|
||||
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
|
||||
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs, bool kIsVarlen = false>
|
||||
CUTLASS_GLOBAL __launch_bounds__(32, 1)
|
||||
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
|
||||
const uint32_t* context_lens, uint32_t* schedule_metadata) {
|
||||
const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) {
|
||||
DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
|
||||
const uint32_t lane_idx = ptx::get_lane_idx();
|
||||
|
||||
// Wait for primary kernel completion
|
||||
cudaGridDependencySynchronize();
|
||||
|
||||
__shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize];
|
||||
__shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize];
|
||||
__shared__ uint32_t varlen_num_atoms_shared;
|
||||
uint32_t num_items;
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
if (lane_idx == 0) {
|
||||
uint32_t t = 0, atom_count = 0;
|
||||
while (t < batch_size) {
|
||||
varlen_atom_token_start[atom_count] = t;
|
||||
const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]);
|
||||
varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t];
|
||||
t += is_paired ? 2 : 1;
|
||||
++ atom_count;
|
||||
}
|
||||
varlen_num_atoms_shared = atom_count;
|
||||
}
|
||||
__syncwarp();
|
||||
num_items = varlen_num_atoms_shared;
|
||||
} else {
|
||||
num_items = batch_size;
|
||||
}
|
||||
|
||||
// Compute num_segs and prefix sum
|
||||
uint32_t num_segs[kAlignedBatchSize / 32];
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
|
||||
const uint32_t q_idx = k * 32 + lane_idx;
|
||||
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
|
||||
const uint32_t context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
|
||||
uint32_t context_len;
|
||||
if constexpr (kIsVarlen) {
|
||||
context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0);
|
||||
} else {
|
||||
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
|
||||
context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
|
||||
}
|
||||
num_segs[k] = math::ceil_div(context_len, SPLIT_KV);
|
||||
}
|
||||
|
||||
@@ -40,44 +69,118 @@ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t ne
|
||||
sum = __shfl_sync(0xffffffff, x, 31);
|
||||
}
|
||||
|
||||
const uint32_t num_next_n_atoms = next_n / ((next_n % 2 == 0) ? 2 : 1);
|
||||
const uint32_t total = sum * num_next_n_atoms;
|
||||
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
|
||||
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
|
||||
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
|
||||
uint32_t q_idx = 0;
|
||||
while (q_idx < batch_size and prefix_sum[q_idx] * num_next_n_atoms <= seg_starts)
|
||||
++ q_idx;
|
||||
const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
|
||||
const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]);
|
||||
const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
|
||||
const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
|
||||
const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
|
||||
__syncwarp();
|
||||
// SM work distribution
|
||||
if constexpr (kIsVarlen) {
|
||||
const uint32_t total = sum;
|
||||
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
|
||||
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
|
||||
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
|
||||
uint32_t lo = 0, hi = num_items;
|
||||
while (lo < hi) {
|
||||
const uint32_t mid = (lo + hi) / 2;
|
||||
const bool pred = prefix_sum[mid] <= seg_starts;
|
||||
lo = pred ? mid + 1 : lo;
|
||||
hi = pred ? hi : mid;
|
||||
}
|
||||
const uint32_t atom_idx = lo;
|
||||
const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]);
|
||||
const uint32_t q_atom_idx = (atom_idx < num_items ? varlen_atom_token_start[atom_idx] : batch_size);
|
||||
__syncwarp();
|
||||
|
||||
schedule_metadata[sm_idx * 2] = q_atom_idx;
|
||||
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
|
||||
schedule_metadata[sm_idx * 2] = q_atom_idx;
|
||||
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
|
||||
}
|
||||
} else {
|
||||
const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1;
|
||||
const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom);
|
||||
const uint32_t total = sum * num_next_n_atoms;
|
||||
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
|
||||
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
|
||||
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
|
||||
uint32_t lo = 0, hi = batch_size;
|
||||
while (lo < hi) {
|
||||
const uint32_t mid = (lo + hi) / 2;
|
||||
const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts;
|
||||
lo = pred ? mid + 1 : lo;
|
||||
hi = pred ? hi : mid;
|
||||
}
|
||||
const uint32_t q_idx = lo;
|
||||
const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
|
||||
const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]);
|
||||
const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
|
||||
const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
|
||||
const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
|
||||
__syncwarp();
|
||||
|
||||
schedule_metadata[sm_idx * 2] = q_atom_idx;
|
||||
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t kNextN, bool kIsContextLens2D,
|
||||
// Conditional storage for varlen indices pointer (EBO: zero cost when unused)
|
||||
template <bool kHasIndices>
|
||||
struct IndicesStorage {
|
||||
const uint32_t* indices;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IndicesStorage<false> {};
|
||||
|
||||
template <uint32_t kNextN, bool kIsContextLens2D, bool kIsVarlen,
|
||||
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit,
|
||||
uint32_t kNumNextNAtoms>
|
||||
struct PagedMQALogitsScheduler {
|
||||
struct PagedMQALogitsScheduler : IndicesStorage<kIsVarlen> {
|
||||
const uint32_t* context_lens;
|
||||
uint32_t batch_size;
|
||||
|
||||
uint32_t current_q_atom_idx, current_kv_idx;
|
||||
uint32_t end_q_atom_idx, end_kv_idx;
|
||||
uint32_t current_num_kv;
|
||||
|
||||
CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const {
|
||||
const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
|
||||
const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
|
||||
return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
|
||||
CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) {
|
||||
if constexpr (kIsVarlen) {
|
||||
return q_atom_idx;
|
||||
} else {
|
||||
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
|
||||
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
|
||||
if constexpr (kPadOddN) {
|
||||
return q_atom_idx / kNumNextNAtoms * kNextN + q_atom_idx % kNumNextNAtoms * kNextNAtom;
|
||||
} else {
|
||||
return q_atom_idx * kNextNAtom;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t* context_lens, const uint32_t* schedule_meta) {
|
||||
CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) {
|
||||
if constexpr (kIsVarlen) {
|
||||
return q_atom_idx;
|
||||
} else {
|
||||
return q_atom_idx / kNumNextNAtoms;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const {
|
||||
if constexpr (kIsVarlen) {
|
||||
const bool is_paired = (q_atom_idx + 1 < batch_size and
|
||||
this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]);
|
||||
const uint32_t ctx_len = is_paired ? context_lens[q_atom_idx + 1] : context_lens[q_atom_idx];
|
||||
return math::ceil_div(ctx_len, BLOCK_KV);
|
||||
} else {
|
||||
const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
|
||||
const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
|
||||
return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size,
|
||||
const uint32_t* context_lens,
|
||||
const uint32_t* schedule_meta, const uint32_t* indices) {
|
||||
this->context_lens = context_lens;
|
||||
this->batch_size = batch_size;
|
||||
if constexpr (kIsVarlen) {
|
||||
this->indices = indices;
|
||||
}
|
||||
|
||||
const auto current_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx];
|
||||
const auto end_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx + 1];
|
||||
@@ -87,6 +190,28 @@ struct PagedMQALogitsScheduler {
|
||||
current_num_kv = get_num_kv(current_q_atom_idx);
|
||||
}
|
||||
|
||||
// Advance step in q_atom_idx space when moving to the next atom.
|
||||
// Varlen: 1 or 2 depending on whether consecutive tokens share the same sequence.
|
||||
// Non-varlen: always 1 (one atom unit).
|
||||
CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const {
|
||||
if constexpr (kIsVarlen) {
|
||||
return (q_atom_idx + 1 < bound and this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]) ? 2 : 1;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Whether num_kv should be refreshed after advancing to q_atom_idx.
|
||||
// Varlen: always refresh (each atom may have a different context_len).
|
||||
// Non-varlen: only at atom-group boundaries (atoms within a group share context_len).
|
||||
CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const {
|
||||
if constexpr (kIsVarlen) {
|
||||
return true;
|
||||
} else {
|
||||
return q_atom_idx % kNumNextNAtoms == 0;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) {
|
||||
q_atom_idx = current_q_atom_idx;
|
||||
kv_idx = current_kv_idx;
|
||||
@@ -97,9 +222,9 @@ struct PagedMQALogitsScheduler {
|
||||
|
||||
current_kv_idx += kNumBlocksPerSplit;
|
||||
if (current_kv_idx >= current_num_kv) {
|
||||
++ current_q_atom_idx;
|
||||
current_kv_idx = 0;
|
||||
if (current_q_atom_idx % kNumNextNAtoms == 0 and exist_q_atom_idx(current_q_atom_idx)) {
|
||||
current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx);
|
||||
if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) {
|
||||
current_num_kv = get_num_kv(current_q_atom_idx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,10 +61,8 @@ def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
|
||||
hidden: int, intermediate_hidden: int,
|
||||
use_fp8_dispatch: bool = True,
|
||||
activation: str = 'swiglu') -> SymmBuffer:
|
||||
# Token count must be aligned to block m
|
||||
num_ranks = group.size()
|
||||
block_m = _C.get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk)
|
||||
num_max_tokens_per_rank = align(num_max_tokens_per_rank, block_m)
|
||||
# Token count must be aligned to block sizes
|
||||
num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe())
|
||||
|
||||
return SymmBuffer(
|
||||
group, num_experts,
|
||||
@@ -111,6 +109,7 @@ def fp8_fp4_mega_moe(y: torch.Tensor,
|
||||
l1_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
l2_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
sym_buffer: SymmBuffer,
|
||||
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
|
||||
recipe: Tuple[int, int, int] = (1, 1, 32),
|
||||
activation: str = 'swiglu',
|
||||
activation_clamp: Optional[float] = None,
|
||||
@@ -118,6 +117,7 @@ def fp8_fp4_mega_moe(y: torch.Tensor,
|
||||
_C.fp8_fp4_mega_moe(
|
||||
y,
|
||||
l1_weights, l2_weights,
|
||||
cumulative_local_expert_recv_stats,
|
||||
sym_buffer.buffer,
|
||||
sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
|
||||
sym_buffer.num_max_tokens_per_rank,
|
||||
|
||||
448
scripts/quick_plot_pm.py
Normal file
448
scripts/quick_plot_pm.py
Normal file
@@ -0,0 +1,448 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Plot a curated set of NCU PM metrics from an .ncu-rep report.
|
||||
|
||||
Usage:
|
||||
python scripts/quick_plot_pm.py [report.ncu-rep]
|
||||
|
||||
By default the script saves a PNG next to the report.
|
||||
With --interactive, it opens a Qt window instead.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import io
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MetricSpec:
|
||||
name: str
|
||||
metric: str
|
||||
kind: str
|
||||
category: str
|
||||
aliases: tuple[str, ...] = ()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedMetricSpec:
|
||||
name: str
|
||||
metric: str
|
||||
kind: str
|
||||
category: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MetricSeries:
|
||||
name: str
|
||||
metric: str
|
||||
category: str
|
||||
unit: str
|
||||
values: tuple[float, ...]
|
||||
|
||||
|
||||
CATEGORY_ORDER = [
|
||||
"Overview",
|
||||
"SM",
|
||||
"L1",
|
||||
"L2",
|
||||
"DRAM",
|
||||
"Interconnect",
|
||||
]
|
||||
|
||||
|
||||
KIND_SUFFIXES = {
|
||||
"pct_peak": [".avg.pct_of_peak_sustained_elapsed"],
|
||||
"pct": [".pct", ".avg.pct_of_peak_sustained_elapsed"],
|
||||
"avg": [".avg"],
|
||||
"sum": [".sum"],
|
||||
"avg_per_second": [".avg.per_second"],
|
||||
"sum_per_second": [".sum.per_second"],
|
||||
"avg_per_cycle_active": [".avg.per_cycle_active"],
|
||||
"avg_per_cycle_elapsed": [".avg.per_cycle_elapsed"],
|
||||
"sum_per_cycle_elapsed": [".sum.per_cycle_elapsed"],
|
||||
}
|
||||
|
||||
|
||||
# Curated from scripts/ncu-metrics.txt, with a few corrections against
|
||||
# `ncu --query-metrics --chip gb100`:
|
||||
# - Blocks launched uses `gr__ctas_launched_realtime`
|
||||
# - SM active cycles uses `sm__cycles_active`
|
||||
# - L2 throughput for GCC requests uses `lts__t_sector_throughput_srcunit_gcc`
|
||||
# - C2C throughput uses `ctc__throughput`
|
||||
# - NVLink RX metrics use the `NVLRX` domain
|
||||
CURATED_METRICS = [
|
||||
MetricSpec("Blocks Launched", "FE_B.TriageCompute.gr__ctas_launched_realtime", "sum_per_cycle_elapsed", "Overview"),
|
||||
MetricSpec("Average Blocks Active", "TPC.TriageCompute.tpc__ctas_active_realtime", "avg_per_cycle_elapsed", "Overview"),
|
||||
MetricSpec("Total Blocks Active", "TPC.TriageCompute.tpc__ctas_active_realtime", "sum_per_cycle_elapsed", "Overview"),
|
||||
MetricSpec("Average CGAs Active", "GPC_B.TriageCompute.gpc__cgas_active_realtime", "avg_per_cycle_elapsed", "Overview"),
|
||||
MetricSpec("Total CGAs Active", "GPC_B.TriageCompute.gpc__cgas_active_realtime", "sum_per_cycle_elapsed", "Overview"),
|
||||
MetricSpec("SM Active Cycles", "TPC.TriageCompute.sm__cycles_active", "avg", "SM"),
|
||||
MetricSpec("Executed IPC Active", "TPC.TriageCompute.sm__inst_executed_realtime", "avg_per_cycle_active", "SM"),
|
||||
MetricSpec("Executed IPC Elapsed", "TPC.TriageCompute.sm__inst_executed_realtime", "avg_per_cycle_elapsed", "SM"),
|
||||
MetricSpec("SM Throughput", "TPC.TriageCompute.sm__inst_executed_realtime", "pct_peak", "SM"),
|
||||
MetricSpec("SM ALU Pipe Throughput", "TPC.TriageCompute.sm__inst_executed_pipe_alu_realtime", "pct_peak", "SM"),
|
||||
MetricSpec("SM FMA Pipe Throughput", "TPC.TriageCompute.sm__pipe_fma_cycles_active_realtime", "pct_peak", "SM"),
|
||||
MetricSpec("SM FMA Heavy Pipe Throughput", "TPC.TriageCompute.sm__pipe_fmaheavy_cycles_active_realtime", "pct_peak", "SM"),
|
||||
MetricSpec("SM FMA Light Pipe Throughput", "TPC.TriageCompute.sm__pipe_fmalite_cycles_active_realtime", "pct_peak", "SM"),
|
||||
MetricSpec("SM Tensor Pipe Throughput", "TPC.TriageCompute.sm__pipe_tensor_cycles_active_realtime", "pct_peak", "SM"),
|
||||
MetricSpec("SM TMEM Pipe Throughput", "SM_A.TriageCompute.sm__mem_tensor_cycles_active_realtime", "pct_peak", "SM"),
|
||||
MetricSpec("SM Uniform Pipe Throughput", "SM_A.TriageCompute.sm__inst_executed_pipe_uniform_realtime", "pct_peak", "SM"),
|
||||
MetricSpec("SM XU Pipe Throughput", "SM_A.TriageCompute.sm__inst_executed_pipe_xu_realtime", "pct_peak", "SM"),
|
||||
MetricSpec("L1 Throughput", "SM_A.TriageCompute.l1tex__throughput", "pct_peak", "L1"),
|
||||
MetricSpec("L1 Sectors", "SM_B.TriageCompute.l1tex__t_sectors", "sum", "L1"),
|
||||
MetricSpec("L1 Hit Rate", "SM_B.TriageCompute.l1tex__t_sector_hit_rate", "pct", "L1"),
|
||||
MetricSpec("L1 Lookup Hit", "SM_B.TriageCompute.l1tex__t_sectors_lookup_hit", "sum", "L1"),
|
||||
MetricSpec("L1 Lookup Miss", "SM_B.TriageCompute.l1tex__t_sectors_lookup_miss", "sum", "L1"),
|
||||
MetricSpec("L1 Wavefronts (Data)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts", "avg", "L1"),
|
||||
MetricSpec("L1 Wavefronts (LGDS)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts_mem_lgds", "avg", "L1"),
|
||||
MetricSpec("L1 Wavefronts (Shared)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts_mem_shared", "avg", "L1"),
|
||||
MetricSpec("L2 Throughput", "LTS.TriageCompute.lts__throughput", "pct_peak", "L2"),
|
||||
MetricSpec("L2 Throughput for L1 Requests", "LTS.TriageCompute.lts__t_sector_throughput_srcunit_tex", "pct_peak", "L2"),
|
||||
MetricSpec("L2 Throughput for GCC Requests", "LTS.TriageCompute.lts__t_sector_throughput_srcunit_gcc", "pct_peak", "L2"),
|
||||
MetricSpec("L2 Throughput to DRAM", "LTS.TriageCompute.lts__t_sector_throughput_srcnode_fbp", "pct_peak", "L2"),
|
||||
MetricSpec("SysL2 Throughput to Peer Memory", "SYSLTS.TriageCompute.syslts__t_sector_throughput_aperture_peer", "pct_peak", "L2"),
|
||||
MetricSpec("SysL2 Throughput to System Memory", "SYSLTS.TriageCompute.syslts__t_sector_throughput_aperture_sysmem", "pct_peak", "L2"),
|
||||
MetricSpec("L2 Hit Rate", "LTS.TriageCompute.lts__average_t_sector_hit_rate_realtime", "pct", "L2"),
|
||||
MetricSpec("L2 Hit Rate From L1", "LTS.TriageCompute.lts__average_t_sector_hit_rate_srcunit_tex_realtime", "pct", "L2"),
|
||||
MetricSpec("DRAM Frequency", "FBSP.TriageCompute.dram__cycles_elapsed", "avg_per_second", "DRAM"),
|
||||
MetricSpec("DRAM Throughput", "FBSP.TriageCompute.dram__throughput", "pct_peak", "DRAM"),
|
||||
MetricSpec("DRAM Read Throughput", "FBSP.TriageCompute.dram__read_throughput", "pct_peak", "DRAM"),
|
||||
MetricSpec("DRAM Write Throughput", "FBSP.TriageCompute.dram__write_throughput", "pct_peak", "DRAM"),
|
||||
MetricSpec("C2C Throughput", "TriageCompute.ctc__throughput", "pct_peak", "Interconnect", aliases=("TriageCompute.ctx__throughput",)),
|
||||
MetricSpec("NVLink Transmitted Throughput", "NVLTX.TriageCompute.nvltx__bytes", "pct_peak", "Interconnect"),
|
||||
MetricSpec("NVLink Received Throughput", "NVLRX.TriageCompute.nvlrx__bytes", "pct_peak", "Interconnect"),
|
||||
MetricSpec("NVLink Transmitted Bandwidth", "NVLTX.TriageCompute.nvltx__bytes", "sum_per_second", "Interconnect"),
|
||||
MetricSpec("NVLink Received Bandwidth", "NVLRX.TriageCompute.nvlrx__bytes", "sum_per_second", "Interconnect"),
|
||||
MetricSpec("PCIe Throughput", "PCI.TriageCompute.pcie__throughput", "pct_peak", "Interconnect"),
|
||||
MetricSpec("PCIe Read Bandwidth", "PCI.TriageCompute.pcie__read_bytes", "sum_per_second", "Interconnect"),
|
||||
MetricSpec("PCIe Write Bandwidth", "PCI.TriageCompute.pcie__write_bytes", "sum_per_second", "Interconnect"),
|
||||
]
|
||||
|
||||
|
||||
def _run_csv_command(command, timeout):
|
||||
result = subprocess.run(command, capture_output=True, text=True, timeout=timeout)
|
||||
if result.returncode != 0 and not result.stdout:
|
||||
return None
|
||||
reader = csv.reader(io.StringIO(result.stdout))
|
||||
return list(reader)
|
||||
|
||||
|
||||
def _query_available_metrics(chip):
|
||||
result = subprocess.run(
|
||||
["ncu", "--query-metrics", "--chip", chip],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(result.stderr.strip() or f"failed to query metrics for chip {chip}")
|
||||
|
||||
metrics = set()
|
||||
for line in result.stdout.splitlines():
|
||||
parts = line.split()
|
||||
if not parts:
|
||||
continue
|
||||
token = parts[0]
|
||||
if "__" not in token:
|
||||
continue
|
||||
metrics.add(token)
|
||||
return metrics
|
||||
|
||||
|
||||
def _metric_candidates(metric):
|
||||
candidates = [metric]
|
||||
marker = ".TriageCompute."
|
||||
if marker in metric:
|
||||
candidates.append(metric.split(marker, 1)[1])
|
||||
return candidates
|
||||
|
||||
|
||||
def resolve_metric_specs(chip):
|
||||
available = _query_available_metrics(chip)
|
||||
resolved = []
|
||||
missing = []
|
||||
|
||||
for spec in CURATED_METRICS:
|
||||
candidates = []
|
||||
for metric in (spec.metric, *spec.aliases):
|
||||
candidates.extend(_metric_candidates(metric))
|
||||
|
||||
actual_metric = next((metric for metric in candidates if metric in available), None)
|
||||
if actual_metric is None:
|
||||
missing.append(spec)
|
||||
continue
|
||||
|
||||
resolved.append(ResolvedMetricSpec(spec.name, actual_metric, spec.kind, spec.category))
|
||||
|
||||
return resolved, missing
|
||||
|
||||
|
||||
def _parse_metric_values(raw):
|
||||
if not raw or raw == "no data":
|
||||
return ()
|
||||
|
||||
try:
|
||||
if raw.startswith("(") and raw.endswith(")"):
|
||||
rest = raw[1:-1]
|
||||
return tuple(float(v.strip().replace(",", "")) for v in rest.split(";") if v.strip())
|
||||
if " (" in raw:
|
||||
_agg, rest = raw.split(" (", 1)
|
||||
rest = rest.rstrip(")")
|
||||
return tuple(float(v.strip().replace(",", "")) for v in rest.split(";") if v.strip())
|
||||
return (float(raw.replace(",", "")),)
|
||||
except ValueError:
|
||||
return ()
|
||||
|
||||
|
||||
def _probe_metric_series(report, metric_name):
|
||||
rows = _run_csv_command(
|
||||
[
|
||||
"ncu",
|
||||
"--import",
|
||||
report,
|
||||
"--page",
|
||||
"raw",
|
||||
"--csv",
|
||||
"--metrics",
|
||||
metric_name,
|
||||
"--print-metric-instances",
|
||||
"values",
|
||||
],
|
||||
timeout=60,
|
||||
)
|
||||
if not rows or len(rows) < 3 or len(rows[0]) <= 11:
|
||||
return None
|
||||
|
||||
header, units, row = rows[0], rows[1], rows[2]
|
||||
unit = units[11] if len(units) > 11 else ""
|
||||
raw = row[11] if len(row) > 11 else ""
|
||||
values = _parse_metric_values(raw)
|
||||
return header[11], unit, values
|
||||
|
||||
|
||||
def collect_metric_series(report, resolved_specs):
|
||||
collected = []
|
||||
skipped = []
|
||||
|
||||
for spec in resolved_specs:
|
||||
series = None
|
||||
for suffix in KIND_SUFFIXES[spec.kind]:
|
||||
probe = _probe_metric_series(report, f"{spec.metric}{suffix}")
|
||||
if probe is None:
|
||||
continue
|
||||
full_metric, unit, values = probe
|
||||
if len(values) > 1:
|
||||
series = MetricSeries(spec.name, full_metric, spec.category, unit, values)
|
||||
break
|
||||
|
||||
if series is None:
|
||||
skipped.append(spec)
|
||||
continue
|
||||
|
||||
collected.append(series)
|
||||
|
||||
return collected, skipped
|
||||
|
||||
|
||||
def _format_value(value):
|
||||
if value == 0:
|
||||
return "0"
|
||||
abs_value = abs(value)
|
||||
if abs_value >= 1e12:
|
||||
return f"{value / 1e12:.2f} T"
|
||||
if abs_value >= 1e9:
|
||||
return f"{value / 1e9:.2f} G"
|
||||
if abs_value >= 1e6:
|
||||
return f"{value / 1e6:.2f} M"
|
||||
if abs_value >= 1e3:
|
||||
return f"{value / 1e3:.2f} K"
|
||||
if abs_value >= 1:
|
||||
return f"{value:.1f}"
|
||||
return f"{value:.2f}"
|
||||
|
||||
|
||||
def _format_with_unit(value, unit):
|
||||
if not unit:
|
||||
return _format_value(value)
|
||||
return f"{_format_value(value)} {unit}"
|
||||
|
||||
|
||||
def plot_pm(report, metrics, save=False):
|
||||
"""Plot curated PM metrics as shared-x subplots in a light theme."""
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.gridspec import GridSpec
|
||||
|
||||
if not metrics:
|
||||
print("No curated metrics had time-series data in the report.")
|
||||
return
|
||||
|
||||
bg_fig = "#ffffff"
|
||||
bg_row = "#f6f8fb"
|
||||
text_primary = "#1f2937"
|
||||
text_secondary = "#6b7280"
|
||||
text_header = "#111827"
|
||||
grid_color = "#d7deea"
|
||||
border = "#c7d0dd"
|
||||
|
||||
wave_colors = {
|
||||
"Overview": "#7c8aa5",
|
||||
"SM": "#4f87c2",
|
||||
"L1": "#2f9d8f",
|
||||
"L2": "#dd8452",
|
||||
"DRAM": "#c95d63",
|
||||
"Interconnect": "#8c6bb1",
|
||||
}
|
||||
|
||||
category_rank = {category: index for index, category in enumerate(CATEGORY_ORDER)}
|
||||
metrics = sorted(metrics, key=lambda item: (category_rank.get(item.category, 99), item.name))
|
||||
|
||||
row_h = 0.55
|
||||
label_w = 3.6
|
||||
plot_w = 14.0
|
||||
fig_w = label_w + plot_w
|
||||
fig_h = row_h * len(metrics) + 0.6
|
||||
|
||||
fig = plt.figure(figsize=(fig_w, fig_h), facecolor=bg_fig)
|
||||
gs = GridSpec(
|
||||
len(metrics),
|
||||
1,
|
||||
figure=fig,
|
||||
left=label_w / fig_w,
|
||||
right=0.97,
|
||||
top=1 - 0.45 / fig_h,
|
||||
bottom=0.35 / fig_h,
|
||||
hspace=0.18,
|
||||
)
|
||||
axes = [fig.add_subplot(gs[i, 0]) for i in range(len(metrics))]
|
||||
|
||||
prev_category = None
|
||||
for idx, metric in enumerate(metrics):
|
||||
ax = axes[idx]
|
||||
values = np.array(metric.values)
|
||||
x = np.arange(len(values))
|
||||
wave_color = wave_colors.get(metric.category, "#5b9bd5")
|
||||
|
||||
ax.set_facecolor(bg_row)
|
||||
ax.fill_between(x, values, alpha=0.35, color=wave_color, linewidth=0)
|
||||
ax.plot(x, values, linewidth=0.8, color=wave_color)
|
||||
|
||||
ax.set_xlim(0, len(values) - 1)
|
||||
if metric.unit == "%":
|
||||
ax.set_ylim(0, 100)
|
||||
else:
|
||||
ymax = np.max(values)
|
||||
ax.set_ylim(0, ymax * 1.15 if ymax > 0 else 1)
|
||||
|
||||
ax.grid(True, axis="both", color=grid_color, linewidth=0.5, alpha=0.85)
|
||||
ax.tick_params(axis="both", colors=text_secondary, labelsize=6, length=0)
|
||||
|
||||
if idx < len(metrics) - 1:
|
||||
ax.tick_params(axis="x", labelbottom=False)
|
||||
else:
|
||||
ax.set_xlabel("Sample Index", fontsize=8, color=text_secondary)
|
||||
|
||||
ymin_v, ymax_v = ax.get_ylim()
|
||||
ax.set_yticks([ymin_v, ymax_v])
|
||||
ax.set_yticklabels([_format_value(ymin_v), _format_value(ymax_v)], fontsize=6, color=text_secondary)
|
||||
|
||||
peak = np.max(values)
|
||||
ax.text(
|
||||
1.005,
|
||||
0.5,
|
||||
_format_with_unit(peak, metric.unit),
|
||||
transform=ax.transAxes,
|
||||
fontsize=7,
|
||||
color=text_secondary,
|
||||
va="center",
|
||||
ha="left",
|
||||
family="monospace",
|
||||
)
|
||||
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color(border)
|
||||
spine.set_linewidth(0.5)
|
||||
|
||||
if metric.category != prev_category:
|
||||
cat_y = ax.get_position().y1 + 0.008
|
||||
fig.text(
|
||||
0.005,
|
||||
cat_y,
|
||||
f" {metric.category}",
|
||||
fontsize=8.5,
|
||||
fontweight="bold",
|
||||
color=text_header,
|
||||
va="bottom",
|
||||
family="sans-serif",
|
||||
transform=fig.transFigure,
|
||||
bbox=dict(boxstyle="square,pad=0.15", facecolor="#e9eef5", edgecolor="none"),
|
||||
)
|
||||
prev_category = metric.category
|
||||
|
||||
label_y = (ax.get_position().y0 + ax.get_position().y1) / 2
|
||||
fig.text(
|
||||
label_w / fig_w - 0.012,
|
||||
label_y,
|
||||
metric.name,
|
||||
fontsize=7.5,
|
||||
color=text_primary,
|
||||
va="center",
|
||||
ha="right",
|
||||
family="sans-serif",
|
||||
transform=fig.transFigure,
|
||||
)
|
||||
|
||||
fig.text(
|
||||
0.5,
|
||||
1 - 0.15 / fig_h,
|
||||
f"PM Sampling - {report}",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
color=text_header,
|
||||
ha="center",
|
||||
va="top",
|
||||
family="sans-serif",
|
||||
transform=fig.transFigure,
|
||||
)
|
||||
|
||||
if save:
|
||||
out_path = report.replace(".ncu-rep", ".pm_sampling.png")
|
||||
fig.savefig(out_path, dpi=150, facecolor=bg_fig, bbox_inches="tight", pad_inches=0.2)
|
||||
print(f"Saved: {out_path}")
|
||||
plt.close(fig)
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="NCU PM Sampling plotter")
|
||||
parser.add_argument("report", nargs="?", default="mega-moe-kk.3.ncu-rep", help="Path to .ncu-rep file")
|
||||
parser.add_argument("--chip", default="gb100", help="Chip name used for `ncu --query-metrics`")
|
||||
parser.add_argument("--interactive", action="store_true", help="Open an interactive Qt window instead of saving a PNG")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.interactive:
|
||||
matplotlib.use("QtAgg")
|
||||
else:
|
||||
matplotlib.use("Agg")
|
||||
|
||||
resolved_specs, missing_specs = resolve_metric_specs(args.chip)
|
||||
if missing_specs:
|
||||
print(f"Skipped {len(missing_specs)} curated metrics not available on {args.chip}.")
|
||||
for spec in missing_specs:
|
||||
print(f" missing: {spec.name} -> {spec.metric}")
|
||||
|
||||
metric_series, skipped_specs = collect_metric_series(args.report, resolved_specs)
|
||||
if skipped_specs:
|
||||
print(f"Skipped {len(skipped_specs)} curated metrics with no time-series data in {args.report}.")
|
||||
for spec in skipped_specs:
|
||||
print(f" no series: {spec.name} -> {spec.metric}")
|
||||
|
||||
plot_pm(args.report, metric_series, save=not args.interactive)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
89
scripts/run_ncu_mega_moe.sh
Executable file
89
scripts/run_ncu_mega_moe.sh
Executable file
@@ -0,0 +1,89 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# parse num-processes, output_dir and separate python args
|
||||
num_processes=8
|
||||
output_dir=work
|
||||
python_args=()
|
||||
for ((arg_idx = 1; arg_idx <= $#; ++arg_idx)); do
|
||||
arg="${!arg_idx}"
|
||||
case "$arg" in
|
||||
--num-processes)
|
||||
python_args+=("$arg")
|
||||
if ((arg_idx < $#)); then
|
||||
((arg_idx++))
|
||||
num_processes="${!arg_idx}"
|
||||
python_args+=("$num_processes")
|
||||
fi
|
||||
;;
|
||||
-h|--help)
|
||||
echo "Usage: $0 [--num-processes N] [--output DIR] [python args...]"
|
||||
exit 0
|
||||
;;
|
||||
--num-processes=*)
|
||||
num_processes="${arg#*=}"
|
||||
python_args+=("$arg")
|
||||
;;
|
||||
-o|--output)
|
||||
if ((arg_idx < $#)); then
|
||||
((arg_idx++))
|
||||
output_dir="${!arg_idx}"
|
||||
fi
|
||||
;;
|
||||
--output=*)
|
||||
output_dir="${arg#*=}"
|
||||
;;
|
||||
*)
|
||||
python_args+=("$arg")
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo "Python Args: ${python_args[*]}"
|
||||
echo "Num Processes: $num_processes"
|
||||
echo "Output Dir: $output_dir"
|
||||
mkdir -p $output_dir
|
||||
|
||||
export DG_JIT_WITH_LINEINFO=1 # for source counters
|
||||
|
||||
echo "Warm up JIT cache"
|
||||
python tests/test_mega_moe.py --ncu-profile-only "${python_args[@]}"
|
||||
|
||||
sleep 2
|
||||
|
||||
ncu_args=(
|
||||
--config-file off
|
||||
--force-overwrite
|
||||
--kernel-name sm100_fp8_fp4_mega_moe_impl
|
||||
--import-source yes
|
||||
--replay-mode application
|
||||
--section PmSampling
|
||||
--section SourceCounters
|
||||
--rule LocalMemoryUsage
|
||||
--launch-skip 0
|
||||
--launch-count 1
|
||||
--lockstep-kernel-launch
|
||||
--communicator tcp
|
||||
--clock-control none
|
||||
--pm-sampling-interval 1000
|
||||
--pm-sampling-max-passes 1
|
||||
--disable-pm-warp-sampling
|
||||
--communicator-tcp-num-peers "$num_processes"
|
||||
--kill yes
|
||||
--app-replay-buffer memory
|
||||
)
|
||||
|
||||
echo "Run Job"
|
||||
|
||||
for ((i = 0; i < num_processes; ++i)); do
|
||||
ncu ${ncu_args[@]} -o "${output_dir%/}/mega-moe.$i" \
|
||||
python tests/test_mega_moe.py \
|
||||
--local-rank-idx=$i \
|
||||
--ncu-profile-only \
|
||||
"${python_args[@]}" &
|
||||
done
|
||||
|
||||
echo "Waiting"
|
||||
wait
|
||||
echo "Done"
|
||||
@@ -253,36 +253,61 @@ def test_paged_mqa_logits():
|
||||
|
||||
def enumerate_paged_mqa_logits():
|
||||
arch_major = get_arch_major()
|
||||
for is_fp4 in ((True, False) if arch_major == 10 else (False, )):
|
||||
for logits_dtype in (torch.float, torch.bfloat16):
|
||||
for block_kv in ((32, 64) if arch_major == 10 else (64, )):
|
||||
for use_2d_context_lens, clean_logits in [(True, False)]:
|
||||
for batch_size in (256, ):
|
||||
for next_n in (1, 2, 4, 5, 6) if arch_major == 10 else (1, 2):
|
||||
for num_heads, head_dim in [(64, 128)]:
|
||||
for avg_kv in (8192, 32768):
|
||||
yield is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, num_heads, head_dim, avg_kv
|
||||
for is_varlen in ((True, False) if arch_major == 10 else (False, )):
|
||||
for is_fp4 in ((True, False) if arch_major == 10 else (False, )):
|
||||
for logits_dtype in (torch.float, torch.bfloat16):
|
||||
for block_kv in ((32, 64) if arch_major == 10 else (64, )):
|
||||
for use_2d_context_lens, clean_logits in [(True, False)]:
|
||||
for batch_size in (256, ):
|
||||
for next_n in ((1, ) if is_varlen else ((1, 2, 4, 5, 6) if arch_major == 10 else (1, 2))):
|
||||
for max_tokens_per_batch in ((1, 4, 10) if is_varlen else (1, )):
|
||||
for num_heads, head_dim in [(64, 128)]:
|
||||
for avg_kv in (8192, 32768):
|
||||
yield is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv
|
||||
|
||||
|
||||
print('Testing FP8/FP4 Paged MQA Logits:')
|
||||
max_model_len = 111 * 1024
|
||||
num_total_blocks = max_model_len * 5
|
||||
|
||||
for is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits():
|
||||
for is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits():
|
||||
# Varlen: flatten raw_batch_size sequences with variable tokens into (batch_size, 1, ...)
|
||||
raw_batch_size, raw_next_n = batch_size, next_n
|
||||
if is_varlen:
|
||||
tokens_per_seq = torch.randint(1, max_tokens_per_batch + 1, (raw_batch_size,), device='cuda', dtype=torch.int)
|
||||
indices = torch.arange(raw_batch_size, device='cuda', dtype=torch.int).repeat_interleave(tokens_per_seq)
|
||||
batch_size, next_n = tokens_per_seq.sum().item(), 1
|
||||
else:
|
||||
tokens_per_seq, indices = None, None
|
||||
|
||||
# Generate random inputs
|
||||
q = torch.randn((batch_size, next_n, num_heads, head_dim), device='cuda', dtype=torch.bfloat16)
|
||||
kv_cache = torch.randn((num_total_blocks, block_kv, 1, head_dim), device='cuda', dtype=torch.bfloat16)
|
||||
weights = torch.randn((batch_size * next_n, num_heads), device='cuda', dtype=torch.float)
|
||||
context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size,), device='cuda', dtype=torch.int)
|
||||
context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (raw_batch_size,), device='cuda', dtype=torch.int)
|
||||
|
||||
# Assign block tables
|
||||
num_blocks_per_query = ceil_div(context_lens, block_kv)
|
||||
block_table = torch.empty((batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int)
|
||||
if is_varlen:
|
||||
max_ctx_len_per_seq = context_lens + (tokens_per_seq - 1)
|
||||
else:
|
||||
max_ctx_len_per_seq = context_lens
|
||||
|
||||
# Assign block tables (per-sequence, sized by the largest ctx_len within the sequence)
|
||||
seq_sum_lens = context_lens.sum().item()
|
||||
num_blocks_per_query = ceil_div(max_ctx_len_per_seq, block_kv)
|
||||
block_table = torch.empty((raw_batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int)
|
||||
block_idx_pool = torch.randperm(num_total_blocks, device='cuda', dtype=torch.int)
|
||||
offset = 0
|
||||
for i, num_blocks in enumerate(num_blocks_per_query.tolist()):
|
||||
block_table[i, :num_blocks] = block_idx_pool[offset : offset + num_blocks]
|
||||
offset += num_blocks
|
||||
if is_varlen:
|
||||
context_lens = context_lens.repeat_interleave(tokens_per_seq)
|
||||
offsets_within_seq = torch.cat([
|
||||
torch.arange(n.item(), device='cuda', dtype=torch.int)
|
||||
for n in tokens_per_seq
|
||||
])
|
||||
context_lens = context_lens + offsets_within_seq
|
||||
block_table = block_table.repeat_interleave(tokens_per_seq, dim=0)
|
||||
|
||||
# Calculate reference logits
|
||||
ref_logits = ref_paged_mqa_logits(q, kv_cache, weights, context_lens, block_table, max_model_len, use_2d_context_lens)
|
||||
@@ -304,9 +329,14 @@ def test_paged_mqa_logits():
|
||||
# Prepare masks and context lengths with NextN
|
||||
positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1)
|
||||
if use_2d_context_lens:
|
||||
context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int()
|
||||
# Ensure last token matches actual length
|
||||
context_lens_nextn[:, -1] = context_lens
|
||||
if is_varlen:
|
||||
# Varlen: context_lens is already per-token (shape [total_tokens]);
|
||||
# just reshape to (total_tokens, 1) so each token keeps its own ctx_len.
|
||||
context_lens_nextn = context_lens.view(-1, 1)
|
||||
else:
|
||||
context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int()
|
||||
# Ensure last token matches actual length
|
||||
context_lens_nextn[:, -1] = context_lens
|
||||
ref_neginf_mask = ~(positions < context_lens_nextn.view(-1, 1))
|
||||
else:
|
||||
context_lens_nextn = context_lens
|
||||
@@ -318,8 +348,9 @@ def test_paged_mqa_logits():
|
||||
kernel_kwargs = dict(
|
||||
q=q_in, kv_cache=kv_in, weights=weights,
|
||||
context_lens=context_lens_nextn, block_table=block_table,
|
||||
schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms()),
|
||||
max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype
|
||||
schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms(), indices=indices),
|
||||
max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype,
|
||||
indices=indices,
|
||||
)
|
||||
logits = deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs)
|
||||
|
||||
@@ -342,11 +373,15 @@ def test_paged_mqa_logits():
|
||||
sum_lens = context_lens.sum().item()
|
||||
tflops_calc = 2 * sum_lens * next_n * num_heads * head_dim / 1e12
|
||||
kv_bytes_per_token = head_dim / (2 if is_fp4 else 1) + 4
|
||||
total_bytes = count_bytes(q, weights) + sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize)
|
||||
# KV is read once per sequence; for varlen sum_lens overcounts (per-token), so use seq_sum_lens
|
||||
kv_sum_lens = seq_sum_lens if is_varlen else sum_lens
|
||||
total_bytes = count_bytes(q, weights) + kv_sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize)
|
||||
|
||||
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs), ('paged_mqa_logits', 'clean_logits'))
|
||||
print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={batch_size:3}, NextN={next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: '
|
||||
print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={raw_batch_size:3}, NextN={raw_next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: '
|
||||
f'{tflops_calc / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, {total_bytes / t / 1e9:4.0f} GB/s', end='')
|
||||
if is_varlen:
|
||||
print(f' | Varlen, MaxTPB={max_tokens_per_batch}, NumTokens={batch_size}', end='')
|
||||
print(f' | clean: {clean_t*1e6:3.0f} us' if clean_logits else '')
|
||||
print()
|
||||
|
||||
|
||||
@@ -9,24 +9,28 @@ from typing import Tuple
|
||||
import deep_gemm
|
||||
from deep_gemm.utils import per_token_cast_to_fp4, per_token_cast_to_fp8
|
||||
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
|
||||
from deep_gemm.testing import bench, bench_kineto, calc_diff
|
||||
from deep_gemm.testing import bench_kineto
|
||||
|
||||
# Load legacy implements from third-party
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import deep_ep
|
||||
import importlib.util
|
||||
from tilelang.profiler.bench import do_bench
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'tilelang_ops',
|
||||
os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py'))
|
||||
tilelang_ops = importlib.util.module_from_spec(spec)
|
||||
sys.modules['tilelang_ops'] = tilelang_ops
|
||||
spec.loader.exec_module(tilelang_ops)
|
||||
is_legacy_loaded = True
|
||||
except Exception as ex:
|
||||
print(f'Failed to load legacy code: {ex}, skip baseline benchmarking')
|
||||
is_legacy_loaded = False
|
||||
|
||||
def import_baseline():
|
||||
# Load legacy implements from third-party
|
||||
deep_ep, tilelang_ops, do_bench, is_legacy_loaded = None, None, None, False
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import deep_ep
|
||||
import importlib.util
|
||||
from tilelang.profiler.bench import do_bench
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'tilelang_ops',
|
||||
os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py'))
|
||||
tilelang_ops = importlib.util.module_from_spec(spec)
|
||||
sys.modules['tilelang_ops'] = tilelang_ops
|
||||
spec.loader.exec_module(tilelang_ops)
|
||||
is_legacy_loaded = True
|
||||
except Exception as ex:
|
||||
dist_print(f'Failed to load legacy code: {ex}, skip baseline benchmarking', once_in_node=True)
|
||||
dist_print(once_in_node=True)
|
||||
return deep_ep, tilelang_ops, do_bench, is_legacy_loaded
|
||||
|
||||
|
||||
# TODO: skip the test for SM90
|
||||
@@ -51,29 +55,13 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
num_max_tokens_per_rank, num_topk,
|
||||
hidden, intermediate_hidden
|
||||
)
|
||||
dist_print('Config:', once_in_node=True)
|
||||
dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True)
|
||||
dist_print(f' > Hidden: {hidden}', once_in_node=True)
|
||||
dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True)
|
||||
dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True)
|
||||
dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True)
|
||||
dist_print(once_in_node=True)
|
||||
|
||||
# Non-overlapped baseline: EP dispatch + GEMM + EP combine
|
||||
alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
|
||||
deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
|
||||
ep_buffer = deep_ep.ElasticBuffer(
|
||||
group,
|
||||
num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden,
|
||||
num_topk=num_topk, use_fp8_dispatch=True,
|
||||
explicitly_destroy=True,
|
||||
allow_multiple_reduction=False,
|
||||
gpu_timeout_secs=10, cpu_timeout_secs=30
|
||||
) if is_legacy_loaded else None
|
||||
|
||||
# Create inputs
|
||||
# noinspection PyGlobalUndefined
|
||||
def create_inputs():
|
||||
global x, topk_idx, topk_weights, l1_weights, l2_weights, transformed_l1_weights, transformed_l2_weights
|
||||
global cumulative_local_expert_recv_stats_fused
|
||||
global cumulative_local_expert_recv_stats_baseline
|
||||
x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
l1_weights = torch.randn(
|
||||
(num_experts_per_rank, intermediate_hidden * 2, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
@@ -81,6 +69,9 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
(num_experts_per_rank, hidden, intermediate_hidden), dtype=torch.bfloat16, device='cuda')
|
||||
scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda')
|
||||
topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)
|
||||
cumulative_local_expert_recv_stats_fused = torch.randint(
|
||||
0, 100, (num_experts_per_rank, ), dtype=torch.int, device='cuda')
|
||||
cumulative_local_expert_recv_stats_baseline = cumulative_local_expert_recv_stats_fused.clone()
|
||||
if args.masked_ratio > 0:
|
||||
rand_mask = torch.rand_like(topk_idx, dtype=torch.float)
|
||||
topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1)
|
||||
@@ -109,12 +100,67 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
l2_weights = cast_grouped_weights_to_fp4(l2_weights)
|
||||
transformed_l1_weights, transformed_l2_weights = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)
|
||||
|
||||
# Run fused mega MoE
|
||||
# NOTES: copy x into buffer before each call because debug mode zeros the entire buffer
|
||||
def run_fused():
|
||||
buffer.x[:num_tokens].copy_(x[0])
|
||||
buffer.x_sf[:num_tokens].copy_(x[1])
|
||||
buffer.topk_idx[:num_tokens].copy_(topk_idx)
|
||||
buffer.topk_weights[:num_tokens].copy_(topk_weights)
|
||||
|
||||
y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
# noinspection PyTypeChecker
|
||||
deep_gemm.fp8_fp4_mega_moe(
|
||||
y,
|
||||
transformed_l1_weights, transformed_l2_weights,
|
||||
buffer,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_fused,
|
||||
activation_clamp=args.activation_clamp,
|
||||
fast_math=bool(args.fast_math)
|
||||
)
|
||||
return y, cumulative_local_expert_recv_stats_fused
|
||||
|
||||
dist_print('Config:', once_in_node=True)
|
||||
dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True)
|
||||
dist_print(f' > Hidden: {hidden}', once_in_node=True)
|
||||
dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True)
|
||||
dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True)
|
||||
dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True)
|
||||
dist_print(once_in_node=True)
|
||||
|
||||
# Only do NCU profiling
|
||||
if args.ncu_profile_only:
|
||||
create_inputs()
|
||||
dist_print(f'Run fused kernel:', once_in_node=True)
|
||||
run_fused()
|
||||
dist_print(f' > Done, exiting', once_in_node=True)
|
||||
|
||||
# Destroy and exit
|
||||
dist.barrier()
|
||||
buffer.destroy()
|
||||
dist.destroy_process_group()
|
||||
return
|
||||
|
||||
# Non-overlapped baseline: EP dispatch + GEMM + EP combine
|
||||
deep_ep, tilelang_ops, tilelang_bench, is_legacy_loaded = import_baseline()
|
||||
alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
|
||||
deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
|
||||
ep_buffer = deep_ep.ElasticBuffer(
|
||||
group,
|
||||
num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden,
|
||||
num_topk=num_topk, use_fp8_dispatch=True,
|
||||
explicitly_destroy=True,
|
||||
allow_multiple_reduction=False,
|
||||
gpu_timeout_secs=10, cpu_timeout_secs=30
|
||||
) if is_legacy_loaded else None
|
||||
|
||||
def run_baseline():
|
||||
recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
|
||||
x, topk_idx=topk_idx, topk_weights=topk_weights,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_baseline,
|
||||
num_experts=num_experts, expert_alignment=alignment,
|
||||
do_cpu_sync=False, do_handle_copy=False,
|
||||
do_expand=True, use_tma_aligned_col_major_sf=True
|
||||
do_expand=True, use_tma_aligned_col_major_sf=True,
|
||||
)
|
||||
n = recv_x[0].size(0)
|
||||
l1_y = torch.empty((n, intermediate_hidden * 2), dtype=torch.bfloat16, device='cuda')
|
||||
@@ -138,26 +184,7 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
|
||||
l1_y, l2_weights, l2_y, handle.psum_num_recv_tokens_per_expert,
|
||||
use_psum_layout=True, recipe=(1, 1, 32))
|
||||
return ep_buffer.combine(l2_y, handle=handle)[0]
|
||||
|
||||
# Run fused mega MoE
|
||||
# NOTES: copy x into buffer before each call because debug mode zeros the entire buffer
|
||||
def run_fused():
|
||||
buffer.x[:num_tokens].copy_(x[0])
|
||||
buffer.x_sf[:num_tokens].copy_(x[1])
|
||||
buffer.topk_idx[:num_tokens].copy_(topk_idx)
|
||||
buffer.topk_weights[:num_tokens].copy_(topk_weights)
|
||||
|
||||
y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
# noinspection PyTypeChecker
|
||||
deep_gemm.fp8_fp4_mega_moe(
|
||||
y,
|
||||
transformed_l1_weights, transformed_l2_weights,
|
||||
buffer,
|
||||
activation_clamp=args.activation_clamp,
|
||||
fast_math=bool(args.fast_math)
|
||||
)
|
||||
return y
|
||||
return ep_buffer.combine(l2_y, handle=handle)[0], cumulative_local_expert_recv_stats_baseline
|
||||
|
||||
# Check correctness (must be bitwise identical)
|
||||
num_correctness_tests = 1 if args.num_correctness_tests is None else args.num_correctness_tests
|
||||
@@ -166,34 +193,36 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
dist_print('Running correctness tests:', once_in_node=True)
|
||||
for i in range(num_correctness_tests):
|
||||
create_inputs()
|
||||
assert torch.equal(run_fused(), run_baseline())
|
||||
for fused_result, baseline_result in zip(run_fused(), run_baseline()):
|
||||
assert torch.equal(fused_result, baseline_result)
|
||||
if (i + 1) % 100 == 0 or i == num_correctness_tests - 1:
|
||||
dist_print(f' > Correctness test #{i + 1}/{args.num_correctness_tests} passed', once_in_node=True)
|
||||
dist_print(f' > Correctness test #{i + 1}/{num_correctness_tests} passed', once_in_node=True)
|
||||
dist_print(once_in_node=True)
|
||||
else:
|
||||
create_inputs()
|
||||
|
||||
# Count local received tokens
|
||||
gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
|
||||
num_recv_tokens = (rank_idx * num_experts_per_rank <= gathered_topk_idx) & \
|
||||
(gathered_topk_idx < (rank_idx + 1) * num_experts_per_rank)
|
||||
num_recv_tokens = num_recv_tokens.sum().item()
|
||||
gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | \
|
||||
(gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1
|
||||
num_recv_tokens = (gathered_topk_idx != -1).sum().item()
|
||||
|
||||
# Benchmark
|
||||
t_fused = bench_kineto(
|
||||
run_fused, 'mega_moe',
|
||||
barrier=lambda: ep_buffer.barrier(use_comm_stream=False) if ep_buffer else dist.barrier(),
|
||||
trace_path=None if not args.dump_profile_traces else f'{args.dump_profile_traces}/mega_moe_rank{rank_idx}.json')
|
||||
t_baseline = do_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0
|
||||
t_baseline = tilelang_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0
|
||||
|
||||
# TFLOPS: 3 matmuls (L1 left, L1 right, L2), each 2 * M * N * K
|
||||
safe_div = lambda a, b: float('nan') if b == 0 else a / b
|
||||
tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused)
|
||||
|
||||
# HBM bytes: weights (FP4 packed = 0.5 bytes) + activations (FP8 = 1 byte) + output (BF16 = 2 bytes)
|
||||
num_touched_experts = torch.unique(gathered_topk_idx.flatten()).numel() - 1 # NOTES minus 1 to exclude "-1"
|
||||
num_hbm_bytes = (
|
||||
num_experts_per_rank * intermediate_hidden * 2 * hidden // 2 + # L1 weights (FP4)
|
||||
num_experts_per_rank * hidden * intermediate_hidden // 2 + # L2 weights (FP4)
|
||||
num_touched_experts * intermediate_hidden * 2 * hidden // 2 + # L1 weights (FP4)
|
||||
num_touched_experts * hidden * intermediate_hidden // 2 + # L2 weights (FP4)
|
||||
num_recv_tokens * hidden + # L1 acts read (FP8)
|
||||
num_recv_tokens * intermediate_hidden + # L1 output write (FP8)
|
||||
num_recv_tokens * intermediate_hidden + # L2 acts read (FP8)
|
||||
@@ -230,7 +259,9 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Test PyTorch symmetric memory')
|
||||
|
||||
# Resource settings
|
||||
parser.add_argument('--ncu-profile-only', action='store_true', help='Only run profiling without correctness test')
|
||||
parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
|
||||
|
||||
# Model settings
|
||||
|
||||
Reference in New Issue
Block a user