[fix] Remove trtllm ragged mla prefills (#36540)

Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
This commit is contained in:
Olya Kozlova
2026-03-31 21:30:27 +02:00
committed by GitHub
parent b779eb3363
commit 598190aac3
8 changed files with 185 additions and 35 deletions

View File

@@ -3,6 +3,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <limits>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
@@ -17,7 +18,7 @@ __global__ void merge_attn_states_kernel(
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size, const uint prefix_head_stride,
const uint output_head_stride) {
const uint output_head_stride, const uint prefix_num_tokens) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
@@ -43,6 +44,22 @@ __global__ void merge_attn_states_kernel(
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
scalar_t* output_head_ptr = output + dst_head_offset;
// Copy-only path: a token at or beyond prefix_num_tokens has no prefix
// (context) contribution, so the suffix result is already the final result
// and the LSE-weighted merge further down is skipped for it.
if (token_idx >= prefix_num_tokens) {
if (pack_offset < head_size) {
// Move one 128-bit pack (pack_size elements of scalar_t) of this head from
// the suffix output straight into the output tensor.
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
suffix_head_ptr)[pack_offset / pack_size];
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
s_out_pack;
}
// Forward the suffix LSE unchanged (indexed as [head, token]). Unlike the
// merge path, no +inf -> -inf normalization is applied here.
// NOTE(review): pack_idx == 0 presumably selects a single thread per
// (token, head) so the LSE is written once; confirm against its definition.
if (output_lse != nullptr && pack_idx == 0) {
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
output_lse[head_idx * num_tokens + token_idx] = s_lse;
}
return;
}
// For tokens within prefix range, merge prefix and suffix
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
@@ -143,7 +160,8 @@ __global__ void merge_attn_states_kernel(
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size, prefix_head_stride, output_head_stride); \
num_heads, head_size, prefix_head_stride, output_head_stride, \
prefix_num_tokens); \
}
/*@brief Merges the attention states from prefix and suffix
@@ -157,14 +175,18 @@ __global__ void merge_attn_states_kernel(
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
* states.
* @param prefill_tokens_with_context Optional number p of leading prefill
* tokens that have prefix context. For tokens with 0 <= token_idx < p, output
* is computed by merging prefix_output and suffix_output. For the remaining
* tokens (p <= token_idx < n), output is copied directly from suffix_output.
* If no value is provided, p = n and every token is merged.
*/
template <typename scalar_t>
void merge_attn_states_launcher(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
void merge_attn_states_launcher(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
@@ -174,6 +196,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
const uint prefix_num_tokens =
prefill_tokens_with_context.has_value()
? static_cast<uint>(prefill_tokens_with_context.value())
: num_tokens;
TORCH_CHECK(prefix_num_tokens <= num_tokens,
"prefix_num_tokens must be <= num_tokens");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
@@ -192,18 +222,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
prefix_lse, suffix_output, \
suffix_lse); \
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>( \
output, output_lse, prefix_output, prefix_lse, suffix_output, \
suffix_lse, prefill_tokens_with_context); \
}
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
void merge_attn_states(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
std::optional<int64_t> prefill_tokens_with_context = std::nullopt) {
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}

View File

@@ -53,12 +53,11 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
void merge_attn_states(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context);
#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]

View File

@@ -73,7 +73,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor prefix_output,"
" Tensor prefix_lse,"
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()");
" Tensor suffix_lse,"
" int!? prefill_tokens_with_context) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def(

View File

@@ -20,7 +20,11 @@ def merge_attn_states_torch(
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS]
prefill_tokens_with_context: int | None = None,
):
# Apply prefill_tokens_with_context mask if needed
if prefill_tokens_with_context is None:
prefill_tokens_with_context = output.shape[0]
p_lse = prefix_lse
s_lse = suffix_lse
# inf -> -inf
@@ -28,6 +32,9 @@ def merge_attn_states_torch(
s_lse[s_lse == torch.inf] = -torch.inf
# max_lse [NUM_HEADS, NUM_TOKENS]
max_lse = torch.maximum(p_lse, s_lse)
mask = torch.ones((prefix_lse.shape[1], 1, 1), device=p_lse.device)
mask[prefill_tokens_with_context:].fill_(0)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
p_lse_exp = torch.exp(p_lse)
@@ -35,11 +42,16 @@ def merge_attn_states_torch(
out_se = p_lse_exp + s_lse_exp
if output_lse is not None:
    output_lse = torch.log(out_se) + max_lse
    # LSE tensors are laid out as [NUM_HEADS, NUM_TOKENS], so the token range
    # has to be selected on dim 1. Slicing dim 0 would index heads instead
    # and, for typical head counts, select nothing at all.
    # Tokens without prefix context keep the suffix LSE unchanged.
    output_lse[:, prefill_tokens_with_context:] = suffix_lse[
        :, prefill_tokens_with_context:
    ]
p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
output = prefix_output * p_scale + suffix_output * s_scale
output.copy_(
prefix_output * p_scale * mask + suffix_output * (s_scale * mask + (1 - mask))
)
return output, output_lse
@@ -90,13 +102,18 @@ def generate_markdown_table():
)
@pytest.mark.parametrize("prefill_tokens_with_context", [None, 128])
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(
num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
prefill_tokens_with_context: int | None,
num_tokens: int,
num_query_heads: int,
head_size: int,
output_dtype: torch.dtype,
):
if not current_platform.is_cuda():
pytest.skip(
@@ -111,6 +128,7 @@ def test_merge_attn_states(
print(
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
f"prefill_tokens_with_context: {prefill_tokens_with_context}, "
f"Device: {current_platform.get_device_name()}"
)
@@ -164,6 +182,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse_torch,
output_lse_torch,
prefill_tokens_with_context,
)
torch.accelerator.synchronize()
@@ -176,6 +195,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse_torch,
output_lse_torch,
prefill_tokens_with_context,
)
end.record()
torch.accelerator.synchronize()
@@ -199,6 +219,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse,
output_lse_ref_triton,
prefill_tokens_with_context,
)
torch.accelerator.synchronize()
@@ -211,6 +232,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse,
output_lse_ref_triton,
prefill_tokens_with_context,
)
end.record()
torch.accelerator.synchronize()
@@ -231,6 +253,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse,
output_lse_cuda,
prefill_tokens_with_context,
)
torch.accelerator.synchronize()
@@ -243,6 +266,7 @@ def test_merge_attn_states(
suffix_output,
suffix_lse,
output_lse_cuda,
prefill_tokens_with_context,
)
end.record()
torch.accelerator.synchronize()

View File

@@ -264,9 +264,16 @@ def merge_attn_states(
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
prefill_tokens_with_context: int | None = None,
) -> None:
torch.ops._C.merge_attn_states(
output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse
output,
output_lse,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
prefill_tokens_with_context,
)

View File

@@ -1181,6 +1181,7 @@ class MLACommonPrefillMetadata:
padded_local_cu_seq_lens: torch.Tensor | None = None
cu_seq_lens_lst: list[list[int]] | None = None
chunk_size: int | None = None
prefill_tokens_with_context: int | None = None
block_table: torch.Tensor
query_start_loc: torch.Tensor
@@ -1743,6 +1744,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_query_start_loc = (
query_start_loc[reqs_start:] - query_start_loc[reqs_start]
)
prefill_query_start_loc_cpu = (
query_start_loc_cpu[reqs_start:] - query_start_loc_cpu[reqs_start]
)
chunked_context_metadata = None
if max_context_len_cpu > 0:
@@ -1864,6 +1868,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if self._use_cudnn_prefill
else MLACommonPrefillMetadata.ChunkedContextMetadata
)
# Number of leading prefill query tokens that belong to requests with cached
# context. Left as None when no prefill has context.
prefill_tokens_with_context = None
if num_prefills_with_context_cpu > 0:
# prefill_query_start_loc_cpu holds cumulative query offsets re-based to the
# first prefill request, so indexing it at the number of prefills with context
# yields the total query-token count of those requests.
# NOTE(review): this assumes prefills with context are ordered before
# context-free prefills in the batch - confirm against the batch reordering.
prefill_tokens_with_context = prefill_query_start_loc_cpu[
num_prefills_with_context_cpu
].item()
if self.dcp_world_size > 1:
chunked_context_metadata = chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
@@ -1883,6 +1892,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
prefill_tokens_with_context=prefill_tokens_with_context,
)
else:
chunked_context_metadata = chunked_context_metadata_cls(
@@ -1896,6 +1906,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
),
chunk_total_token=chunk_total_token,
workspace=self.chunked_prefill_workspace,
prefill_tokens_with_context=prefill_tokens_with_context,
)
if self._use_cudnn_prefill:
@@ -2382,14 +2393,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.workspace_buffer is not None
out = torch.zeros(
out = torch.empty(
q.shape[0],
q.shape[1],
v.shape[2],
device=q.device,
dtype=prefill.output_dtype,
)
prefill.workspace_buffer.fill_(0)
attn_out, lse = trtllm_ragged_attention_deepseek(
query=q,
@@ -2691,6 +2701,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
)
if has_context:
assert prefill_metadata.chunked_context is not None
suffix_output, suffix_lse = output_prefill
if self.dcp_world_size > 1:
context_output, context_lse = (
@@ -2719,6 +2730,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
prefix_lse=context_lse,
suffix_output=suffix_output,
suffix_lse=suffix_lse,
prefill_tokens_with_context=prefill_metadata.chunked_context.prefill_tokens_with_context,
)
else:
output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)

View File

@@ -13,7 +13,36 @@ def merge_attn_states(
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
prefill_tokens_with_context: int | None = None,
) -> None:
"""Merge partial attention outputs from prefix (KV cache) and suffix
(new tokens) into a single output tensor using the log-sum-exp (LSE)
rescaling method described in section 2.2 of
https://www.arxiv.org/pdf/2501.01005.
For tokens that have prefix context (token index < prefill_tokens_with_context),
the prefix and suffix partial outputs are combined as a weighted sum.
For tokens without prefix context, the suffix output is copied directly.
Args:
output: Output tensor of shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
prefix_output: Partial attention output over the prefix (KV cache),
shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
prefix_lse: Log-sum-exp values for the prefix attention,
shape [NUM_HEADS, NUM_TOKENS].
suffix_output: Partial attention output over the suffix (new KV),
shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
suffix_lse: Log-sum-exp values for the suffix attention,
shape [NUM_HEADS, NUM_TOKENS].
output_lse: Optional tensor to store the merged LSE values,
shape [NUM_HEADS, NUM_TOKENS]. If None, LSE is not written out.
prefill_tokens_with_context: Number of prefill tokens that have
prefix context and therefore require merging. Tokens at indices
>= this value are decode or context-free prefill tokens whose
output is taken directly from suffix_output. If None, all tokens
are treated as having context.
"""
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
# does not support FP8 dtype, fallback to use Triton kernel.
def supported_dtypes(o: torch.Tensor) -> bool:
@@ -37,11 +66,23 @@ def merge_attn_states(
from vllm._custom_ops import merge_attn_states
return merge_attn_states(
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
output,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
output_lse,
prefill_tokens_with_context,
)
else:
from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states
return merge_attn_states(
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
output,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
output_lse,
prefill_tokens_with_context,
)

View File

@@ -15,6 +15,7 @@ def merge_attn_states(
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
prefill_tokens_with_context: int | None = None,
) -> None:
num_tokens = output.shape[0]
num_query_heads = output.shape[1]
@@ -25,6 +26,11 @@ def merge_attn_states(
# backend.
prefix_head_stride = prefix_output.stride(1)
output_head_stride = output.stride(1)
# If prefill_tokens_with_context is None, all tokens should use prefix context
if prefill_tokens_with_context is None:
prefill_tokens_with_context = num_tokens
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel[(num_tokens, num_query_heads)](
output,
@@ -38,6 +44,7 @@ def merge_attn_states(
head_size,
padded_head_size,
output_lse is not None,
prefill_tokens_with_context,
)
@@ -54,12 +61,44 @@ def merge_attn_states_kernel(
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
OUTPUT_LSE: tl.constexpr,
prefill_tokens_with_context: tl.constexpr,
):
token_idx = tl.program_id(0)
num_tokens = tl.num_programs(0)
head_idx = tl.program_id(1)
num_heads = tl.num_programs(1)
prefix_mask = token_idx < prefill_tokens_with_context
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
head_mask = head_arange < HEAD_SIZE
# For tokens without context (token_idx >= prefill_tokens_with_context),
# directly copy from suffix_output. prefix_output and prefix_lse are not read
# on this path.
# NOTE(review): prefill_tokens_with_context is declared tl.constexpr, so each
# distinct value presumably produces a separate kernel specialization; confirm
# the recompilation cost is acceptable for varying batch compositions.
if not prefix_mask:
# The suffix LSE is forwarded as loaded, with no inf normalization.
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
if OUTPUT_LSE:
tl.store(output_lse + head_idx * num_tokens + token_idx, s_lse)
# NOTE(review): suffix_output is addressed with prefix_head_stride, which
# assumes it shares prefix_output's head stride - confirm with callers.
s_out = tl.load(
suffix_output
+ token_idx * num_heads * prefix_head_stride
+ head_idx * prefix_head_stride
+ head_arange,
mask=head_mask,
)
tl.store(
output
+ token_idx * num_heads * output_head_stride
+ head_idx * output_head_stride
+ head_arange,
s_out,
mask=head_mask,
)
return
# For tokens with context (token_idx < prefill_tokens_with_context),
# perform normal merge operation
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
@@ -83,8 +122,6 @@ def merge_attn_states_kernel(
out_lse = tl.log(out_se) + max_lse
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(
prefix_output
+ token_idx * num_heads * prefix_head_stride