@@ -108,7 +108,7 @@ static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
if (arch_major == 9 or arch_major == 10) {
|
||||
smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits,
|
||||
seq_len, seq_len_kv, aligned_seq_len_kv, num_heads, head_dim, seq_len_alignment);
|
||||
} else {
|
||||
@@ -130,7 +130,7 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
if (arch_major == 9 or arch_major == 10) {
|
||||
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, block_kv, num_sms);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
@@ -205,7 +205,7 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
if (arch_major == 9 or arch_major == 10) {
|
||||
smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta,
|
||||
batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv,
|
||||
kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, num_math_warp_groups);
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
// Attention kernels
|
||||
#include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh>
|
||||
#include <deep_gemm/impls/sm100_fp8_mqa_logits.cuh>
|
||||
#include <deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh>
|
||||
#include <deep_gemm/impls/smxx_clean_logits.cuh>
|
||||
|
||||
// Einsum kernels
|
||||
|
||||
@@ -58,6 +58,7 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
|
||||
// Calculate shared memory size
|
||||
const int smem_size = aligned_batch_size * static_cast<int>(sizeof(int));
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
|
||||
|
||||
// Launch
|
||||
const SMXXPagedMQALogitsMetadataRuntime::Args& args = {
|
||||
@@ -164,6 +165,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
const int& num_math_warp_groups) {
|
||||
const int num_specialized_threads = 128;
|
||||
const int num_math_threads = num_math_warp_groups * 128;
|
||||
const int num_extra_threads = device_runtime->get_arch_major() == 10 ? 128 : 0;
|
||||
const int num_q_stages = 3, num_kv_stages = 3;
|
||||
const int split_kv = num_math_warp_groups * block_kv;
|
||||
DG_HOST_ASSERT(logits_stride % (num_math_warp_groups * block_kv) == 0);
|
||||
@@ -183,6 +185,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
|
||||
// Calculate shared memory size
|
||||
const int swizzle_alignment = head_dim * 8;
|
||||
|
||||
const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
|
||||
const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
|
||||
const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);
|
||||
@@ -191,8 +194,13 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast<int>(kv_cache_scales.element_size()), swizzle_alignment);
|
||||
const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment);
|
||||
|
||||
const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size;
|
||||
// Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90
|
||||
const int smem_umma_barriers = num_math_warp_groups * 2 * 8;
|
||||
const int smem_tmem_ptr = 4;
|
||||
|
||||
const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr;
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
|
||||
|
||||
// Launch
|
||||
const SMXXFP8PagedMQALogitsRuntime::Args& args = {
|
||||
@@ -217,7 +225,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
.num_specialized_threads = num_specialized_threads,
|
||||
.num_math_threads = num_math_threads,
|
||||
.launch_args = LaunchArgs(num_sms,
|
||||
num_specialized_threads + num_math_threads,
|
||||
num_specialized_threads + num_math_threads + num_extra_threads,
|
||||
smem_size)
|
||||
};
|
||||
const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args);
|
||||
|
||||
Reference in New Issue
Block a user