From 598190aac38a42d8c51ea46a3061e46d9078b3a5 Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Tue, 31 Mar 2026 21:30:27 +0200 Subject: [PATCH] [fix] Remove trtllm ragged mla prefills (#36540) Signed-off-by: Olya Kozlova --- csrc/attention/merge_attn_states.cu | 67 +++++++++++++------ csrc/ops.h | 11 ++- csrc/torch_bindings.cpp | 3 +- .../attention/test_merge_attn_states.py | 28 +++++++- vllm/_custom_ops.py | 9 ++- .../layers/attention/mla_attention.py | 16 ++++- vllm/v1/attention/ops/merge_attn_states.py | 45 ++++++++++++- .../attention/ops/triton_merge_attn_states.py | 41 +++++++++++- 8 files changed, 185 insertions(+), 35 deletions(-) diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index 27d1e990c..f6c1bf617 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -3,6 +3,7 @@ #include #include #include +#include #include "attention_dtypes.h" #include "attention_utils.cuh" @@ -17,7 +18,7 @@ __global__ void merge_attn_states_kernel( const float* prefix_lse, const scalar_t* suffix_output, const float* suffix_lse, const uint num_tokens, const uint num_heads, const uint head_size, const uint prefix_head_stride, - const uint output_head_stride) { + const uint output_head_stride, const uint prefix_num_tokens) { using pack_128b_t = uint4; const uint pack_size = 16 / sizeof(scalar_t); const uint threads_per_head = head_size / pack_size; @@ -43,6 +44,22 @@ __global__ void merge_attn_states_kernel( const scalar_t* suffix_head_ptr = suffix_output + src_head_offset; scalar_t* output_head_ptr = output + dst_head_offset; + // If token_idx >= prefix_num_tokens, just copy from suffix + if (token_idx >= prefix_num_tokens) { + if (pack_offset < head_size) { + pack_128b_t s_out_pack = reinterpret_cast( + suffix_head_ptr)[pack_offset / pack_size]; + reinterpret_cast(output_head_ptr)[pack_offset / pack_size] = + s_out_pack; + } + if (output_lse != nullptr && pack_idx == 0) { + float s_lse = 
suffix_lse[head_idx * num_tokens + token_idx]; + output_lse[head_idx * num_tokens + token_idx] = s_lse; + } + return; + } + + // For tokens within prefix range, merge prefix and suffix float p_lse = prefix_lse[head_idx * num_tokens + token_idx]; float s_lse = suffix_lse[head_idx * num_tokens + token_idx]; p_lse = std::isinf(p_lse) ? -std::numeric_limits::infinity() : p_lse; @@ -143,7 +160,8 @@ __global__ void merge_attn_states_kernel( reinterpret_cast(prefix_lse.data_ptr()), \ reinterpret_cast(suffix_output.data_ptr()), \ reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \ - num_heads, head_size, prefix_head_stride, output_head_stride); \ + num_heads, head_size, prefix_head_stride, output_head_stride, \ + prefix_num_tokens); \ } /*@brief Merges the attention states from prefix and suffix @@ -157,14 +175,18 @@ __global__ void merge_attn_states_kernel( * @param suffix_output [n,h,d] The suffix attention states. * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention * states. + * @param prefill_tokens_with_context Number of prefill tokens with context + * For the first p tokens (0 <= token_idx < prefill_tokens_with_context), output + * is computed by merging prefix_output and suffix_output. For remaining tokens + * (prefill_tokens_with_context <= token_idx < n), output is copied directly + * from suffix_output. 
*/ template -void merge_attn_states_launcher(torch::Tensor& output, - std::optional output_lse, - const torch::Tensor& prefix_output, - const torch::Tensor& prefix_lse, - const torch::Tensor& suffix_output, - const torch::Tensor& suffix_lse) { +void merge_attn_states_launcher( + torch::Tensor& output, std::optional output_lse, + const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse, + const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse, + const std::optional prefill_tokens_with_context) { constexpr uint NUM_THREADS = 128; const uint num_tokens = output.size(0); const uint num_heads = output.size(1); @@ -174,6 +196,14 @@ void merge_attn_states_launcher(torch::Tensor& output, const uint pack_size = 16 / sizeof(scalar_t); TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size); + + const uint prefix_num_tokens = + prefill_tokens_with_context.has_value() + ? static_cast(prefill_tokens_with_context.value()) + : num_tokens; + TORCH_CHECK(prefix_num_tokens <= num_tokens, + "prefix_num_tokens must be <= num_tokens"); + float* output_lse_ptr = nullptr; if (output_lse.has_value()) { output_lse_ptr = output_lse.value().data_ptr(); @@ -192,18 +222,17 @@ void merge_attn_states_launcher(torch::Tensor& output, LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS); } -#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \ - { \ - merge_attn_states_launcher(output, output_lse, prefix_output, \ - prefix_lse, suffix_output, \ - suffix_lse); \ +#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \ + { \ + merge_attn_states_launcher( \ + output, output_lse, prefix_output, prefix_lse, suffix_output, \ + suffix_lse, prefill_tokens_with_context); \ } -void merge_attn_states(torch::Tensor& output, - std::optional output_lse, - const torch::Tensor& prefix_output, - const torch::Tensor& prefix_lse, - const torch::Tensor& suffix_output, - const torch::Tensor& suffix_lse) { +void merge_attn_states( + torch::Tensor& output, std::optional 
output_lse, + const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse, + const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse, + std::optional prefill_tokens_with_context = std::nullopt) { DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER); } diff --git a/csrc/ops.h b/csrc/ops.h index 1fdd77f73..e7886633e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -53,12 +53,11 @@ void paged_attention_v2( const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); -void merge_attn_states(torch::Tensor& output, - std::optional output_lse, - const torch::Tensor& prefix_output, - const torch::Tensor& prefix_lse, - const torch::Tensor& suffix_output, - const torch::Tensor& suffix_lse); +void merge_attn_states( + torch::Tensor& output, std::optional output_lse, + const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse, + const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse, + const std::optional prefill_tokens_with_context); #ifndef USE_ROCM void convert_vertical_slash_indexes( torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 85605458f..4f42477b2 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -73,7 +73,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor prefix_output," " Tensor prefix_lse," " Tensor suffix_output," - " Tensor suffix_lse) -> ()"); + " Tensor suffix_lse," + " int? 
prefill_tokens_with_context) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); #ifndef USE_ROCM ops.def( diff --git a/tests/kernels/attention/test_merge_attn_states.py b/tests/kernels/attention/test_merge_attn_states.py index 6fccb8ccf..c1b71d93e 100644 --- a/tests/kernels/attention/test_merge_attn_states.py +++ b/tests/kernels/attention/test_merge_attn_states.py @@ -20,7 +20,11 @@ def merge_attn_states_torch( suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS] + prefill_tokens_with_context: int | None = None, ): + # Apply prefill_tokens_with_context mask if needed + if prefill_tokens_with_context is None: + prefill_tokens_with_context = output.shape[0] p_lse = prefix_lse s_lse = suffix_lse # inf -> -inf @@ -28,6 +32,9 @@ s_lse[s_lse == torch.inf] = -torch.inf # max_lse [NUM_HEADS, NUM_TOKENS] max_lse = torch.maximum(p_lse, s_lse) + + mask = torch.ones((prefix_lse.shape[1], 1, 1), device=p_lse.device) + mask[prefill_tokens_with_context:].fill_(0) p_lse = p_lse - max_lse s_lse = s_lse - max_lse p_lse_exp = torch.exp(p_lse) @@ -35,11 +42,16 @@ out_se = p_lse_exp + s_lse_exp if output_lse is not None: output_lse = torch.log(out_se) + max_lse + output_lse[:, prefill_tokens_with_context:] = suffix_lse[ + :, prefill_tokens_with_context: + ] p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] - output = prefix_output * p_scale + suffix_output * s_scale + output.copy_( + prefix_output * p_scale * mask + suffix_output * (s_scale * mask + (1 - mask)) + ) return output, output_lse @@ -90,13 +102,18 @@ def generate_markdown_table(): )
+@pytest.mark.parametrize("prefill_tokens_with_context", [None, 128]) @pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS) @pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("output_dtype", DTYPES) @torch.inference_mode() def test_merge_attn_states( - num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype + prefill_tokens_with_context: int | None, + num_tokens: int, + num_query_heads: int, + head_size: int, + output_dtype: torch.dtype, ): if not current_platform.is_cuda(): pytest.skip( @@ -111,6 +128,7 @@ def test_merge_attn_states( print( f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, " f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, " + f"prefill_tokens_with_context: {prefill_tokens_with_context}, " f"Device: {current_platform.get_device_name()}" ) @@ -164,6 +182,7 @@ def test_merge_attn_states( suffix_output, suffix_lse_torch, output_lse_torch, + prefill_tokens_with_context, ) torch.accelerator.synchronize() @@ -176,6 +195,7 @@ def test_merge_attn_states( suffix_output, suffix_lse_torch, output_lse_torch, + prefill_tokens_with_context, ) end.record() torch.accelerator.synchronize() @@ -199,6 +219,7 @@ def test_merge_attn_states( suffix_output, suffix_lse, output_lse_ref_triton, + prefill_tokens_with_context, ) torch.accelerator.synchronize() @@ -211,6 +232,7 @@ def test_merge_attn_states( suffix_output, suffix_lse, output_lse_ref_triton, + prefill_tokens_with_context, ) end.record() torch.accelerator.synchronize() @@ -231,6 +253,7 @@ def test_merge_attn_states( suffix_output, suffix_lse, output_lse_cuda, + prefill_tokens_with_context, ) torch.accelerator.synchronize() @@ -243,6 +266,7 @@ def test_merge_attn_states( suffix_output, suffix_lse, output_lse_cuda, + prefill_tokens_with_context, ) end.record() torch.accelerator.synchronize() diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ea54aaa95..c55f5b923 100644 --- 
a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -264,9 +264,16 @@ def merge_attn_states( suffix_output: torch.Tensor, suffix_lse: torch.Tensor, output_lse: torch.Tensor | None = None, + prefill_tokens_with_context: int | None = None, ) -> None: torch.ops._C.merge_attn_states( - output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse + output, + output_lse, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + prefill_tokens_with_context, ) diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index c77fd09de..4977c62b9 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -1181,6 +1181,7 @@ class MLACommonPrefillMetadata: padded_local_cu_seq_lens: torch.Tensor | None = None cu_seq_lens_lst: list[list[int]] | None = None chunk_size: int | None = None + prefill_tokens_with_context: int | None = None block_table: torch.Tensor query_start_loc: torch.Tensor @@ -1743,6 +1744,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): prefill_query_start_loc = ( query_start_loc[reqs_start:] - query_start_loc[reqs_start] ) + prefill_query_start_loc_cpu = ( + query_start_loc_cpu[reqs_start:] - query_start_loc_cpu[reqs_start] + ) chunked_context_metadata = None if max_context_len_cpu > 0: @@ -1864,6 +1868,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self._use_cudnn_prefill else MLACommonPrefillMetadata.ChunkedContextMetadata ) + prefill_tokens_with_context = None + if num_prefills_with_context_cpu > 0: + prefill_tokens_with_context = prefill_query_start_loc_cpu[ + num_prefills_with_context_cpu + ].item() if self.dcp_world_size > 1: chunked_context_metadata = chunked_context_metadata_cls( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), @@ -1883,6 +1892,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ), cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), 
chunk_size=padded_local_max_context_chunk_across_ranks, + prefill_tokens_with_context=prefill_tokens_with_context, ) else: chunked_context_metadata = chunked_context_metadata_cls( @@ -1896,6 +1906,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ), chunk_total_token=chunk_total_token, workspace=self.chunked_prefill_workspace, + prefill_tokens_with_context=prefill_tokens_with_context, ) if self._use_cudnn_prefill: @@ -2382,14 +2393,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.workspace_buffer is not None - out = torch.zeros( + out = torch.empty( q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=prefill.output_dtype, ) - prefill.workspace_buffer.fill_(0) attn_out, lse = trtllm_ragged_attention_deepseek( query=q, @@ -2691,6 +2701,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ) if has_context: + assert prefill_metadata.chunked_context is not None suffix_output, suffix_lse = output_prefill if self.dcp_world_size > 1: context_output, context_lse = ( @@ -2719,6 +2730,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): prefix_lse=context_lse, suffix_output=suffix_output, suffix_lse=suffix_lse, + prefill_tokens_with_context=prefill_metadata.chunked_context.prefill_tokens_with_context, ) else: output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2) diff --git a/vllm/v1/attention/ops/merge_attn_states.py b/vllm/v1/attention/ops/merge_attn_states.py index 673d2d947..270f65d5e 100644 --- a/vllm/v1/attention/ops/merge_attn_states.py +++ b/vllm/v1/attention/ops/merge_attn_states.py @@ -13,7 +13,36 @@ def merge_attn_states( suffix_output: torch.Tensor, suffix_lse: torch.Tensor, output_lse: torch.Tensor | None = None, + prefill_tokens_with_context: int | None = None, ) -> None: + """Merge partial attention outputs from prefix (KV cache) and suffix + (new tokens) into a single output tensor using the log-sum-exp (LSE) + 
rescaling method described in section 2.2 of + https://www.arxiv.org/pdf/2501.01005. + + For tokens that have prefix context (token index < prefill_tokens_with_context), + the prefix and suffix partial outputs are combined as a weighted sum. + For tokens without prefix context, the suffix output is copied directly. + + Args: + output: Output tensor of shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]. + prefix_output: Partial attention output over the prefix (KV cache), + shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]. + prefix_lse: Log-sum-exp values for the prefix attention, + shape [NUM_HEADS, NUM_TOKENS]. + suffix_output: Partial attention output over the suffix (new KV), + shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]. + suffix_lse: Log-sum-exp values for the suffix attention, + shape [NUM_HEADS, NUM_TOKENS]. + output_lse: Optional tensor to store the merged LSE values, + shape [NUM_HEADS, NUM_TOKENS]. If None, LSE is not written out. + prefill_tokens_with_context: Number of prefill tokens that have + prefix context and therefore require merging. Tokens at indices + >= this value are decode or context-free prefill tokens whose + output is taken directly from suffix_output. If None, all tokens + are treated as having context. + """ + # NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel # does not support FP8 dtype, fallback to use Triton kernel. 
def supported_dtypes(o: torch.Tensor) -> bool: @@ -37,11 +66,23 @@ def merge_attn_states( from vllm._custom_ops import merge_attn_states return merge_attn_states( - output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse + output, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse, + prefill_tokens_with_context, ) else: from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states return merge_attn_states( - output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse + output, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse, + prefill_tokens_with_context, ) diff --git a/vllm/v1/attention/ops/triton_merge_attn_states.py b/vllm/v1/attention/ops/triton_merge_attn_states.py index 74e4d778d..f5b4fbe0b 100644 --- a/vllm/v1/attention/ops/triton_merge_attn_states.py +++ b/vllm/v1/attention/ops/triton_merge_attn_states.py @@ -15,6 +15,7 @@ def merge_attn_states( suffix_output: torch.Tensor, suffix_lse: torch.Tensor, output_lse: torch.Tensor | None = None, + prefill_tokens_with_context: int | None = None, ) -> None: num_tokens = output.shape[0] num_query_heads = output.shape[1] @@ -25,6 +26,11 @@ def merge_attn_states( # backend. prefix_head_stride = prefix_output.stride(1) output_head_stride = output.stride(1) + + # If prefill_tokens_with_context is None, all tokens should use prefix context + if prefill_tokens_with_context is None: + prefill_tokens_with_context = num_tokens + # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. 
merge_attn_states_kernel[(num_tokens, num_query_heads)]( output, @@ -38,6 +44,7 @@ def merge_attn_states( head_size, padded_head_size, output_lse is not None, + prefill_tokens_with_context, ) @@ -54,12 +61,44 @@ def merge_attn_states_kernel( HEAD_SIZE: tl.constexpr, PADDED_HEAD_SIZE: tl.constexpr, OUTPUT_LSE: tl.constexpr, + prefill_tokens_with_context: tl.constexpr, ): token_idx = tl.program_id(0) num_tokens = tl.num_programs(0) head_idx = tl.program_id(1) num_heads = tl.num_programs(1) + prefix_mask = token_idx < prefill_tokens_with_context + + head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + + # For tokens without context (token_idx >= prefill_tokens_with_context), + # directly copy from suffix_output + if not prefix_mask: + s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + if OUTPUT_LSE: + tl.store(output_lse + head_idx * num_tokens + token_idx, s_lse) + + s_out = tl.load( + suffix_output + + token_idx * num_heads * prefix_head_stride + + head_idx * prefix_head_stride + + head_arange, + mask=head_mask, + ) + tl.store( + output + + token_idx * num_heads * output_head_stride + + head_idx * output_head_stride + + head_arange, + s_out, + mask=head_mask, + ) + return + + # For tokens with context (token_idx < prefill_tokens_with_context), + # perform normal merge operation p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) @@ -83,8 +122,6 @@ def merge_attn_states_kernel( out_lse = tl.log(out_se) + max_lse tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) - head_arange = tl.arange(0, PADDED_HEAD_SIZE) - head_mask = head_arange < HEAD_SIZE p_out = tl.load( prefix_output + token_idx * num_heads * prefix_head_stride