@@ -108,7 +108,7 @@ static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
if (arch_major == 9 or arch_major == 10) {
|
||||
smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits,
|
||||
seq_len, seq_len_kv, aligned_seq_len_kv, num_heads, head_dim, seq_len_alignment);
|
||||
} else {
|
||||
@@ -130,7 +130,7 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
if (arch_major == 9 or arch_major == 10) {
|
||||
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, block_kv, num_sms);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
@@ -205,7 +205,7 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
if (arch_major == 9 or arch_major == 10) {
|
||||
smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta,
|
||||
batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv,
|
||||
kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, num_math_warp_groups);
|
||||
|
||||
Reference in New Issue
Block a user