Add SM100 kernels (#201)

Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
Simon Mo
2025-09-29 02:07:28 -07:00
committed by GitHub
parent 80ceeb2c76
commit 59f2c07cf2
6 changed files with 808 additions and 10 deletions

View File

@@ -108,7 +108,7 @@ static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
if (arch_major == 9 or arch_major == 10) {
smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits,
seq_len, seq_len_kv, aligned_seq_len_kv, num_heads, head_dim, seq_len_alignment);
} else {
@@ -130,7 +130,7 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
if (arch_major == 9 or arch_major == 10) {
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, block_kv, num_sms);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
@@ -205,7 +205,7 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
if (arch_major == 9 or arch_major == 10) {
smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta,
batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv,
kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, num_math_warp_groups);