[fix] Remove trtllm ragged mla prefills (#36540)

Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
This commit is contained in:
Olya Kozlova
2026-03-31 21:30:27 +02:00
committed by GitHub
parent b779eb3363
commit 598190aac3
8 changed files with 185 additions and 35 deletions

View File

@@ -3,6 +3,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <limits>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
@@ -17,7 +18,7 @@ __global__ void merge_attn_states_kernel(
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size, const uint prefix_head_stride,
const uint output_head_stride) {
const uint output_head_stride, const uint prefix_num_tokens) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
@@ -43,6 +44,22 @@ __global__ void merge_attn_states_kernel(
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
scalar_t* output_head_ptr = output + dst_head_offset;
// If token_idx >= prefix_num_tokens, just copy from suffix
if (token_idx >= prefix_num_tokens) {
if (pack_offset < head_size) {
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
suffix_head_ptr)[pack_offset / pack_size];
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
s_out_pack;
}
if (output_lse != nullptr && pack_idx == 0) {
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
output_lse[head_idx * num_tokens + token_idx] = s_lse;
}
return;
}
// For tokens within prefix range, merge prefix and suffix
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
@@ -143,7 +160,8 @@ __global__ void merge_attn_states_kernel(
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size, prefix_head_stride, output_head_stride); \
num_heads, head_size, prefix_head_stride, output_head_stride, \
prefix_num_tokens); \
}
/*@brief Merges the attention states from prefix and suffix
@@ -157,14 +175,18 @@ __global__ void merge_attn_states_kernel(
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
* states.
* @param prefill_tokens_with_context Number of prefill tokens with context
* For the first p tokens (0 <= token_idx < prefill_tokens_with_context), output
* is computed by merging prefix_output and suffix_output. For remaining tokens
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
* from suffix_output.
*/
template <typename scalar_t>
void merge_attn_states_launcher(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
void merge_attn_states_launcher(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
@@ -174,6 +196,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
const uint prefix_num_tokens =
prefill_tokens_with_context.has_value()
? static_cast<uint>(prefill_tokens_with_context.value())
: num_tokens;
TORCH_CHECK(prefix_num_tokens <= num_tokens,
"prefix_num_tokens must be <= num_tokens");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
@@ -192,18 +222,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
prefix_lse, suffix_output, \
suffix_lse); \
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>( \
output, output_lse, prefix_output, prefix_lse, suffix_output, \
suffix_lse, prefill_tokens_with_context); \
}
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
void merge_attn_states(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
std::optional<int64_t> prefill_tokens_with_context = std::nullopt) {
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}