[fix] Remove trtllm ragged mla prefills (#36540)
Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
@@ -17,7 +18,7 @@ __global__ void merge_attn_states_kernel(
|
||||
const float* prefix_lse, const scalar_t* suffix_output,
|
||||
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
||||
const uint head_size, const uint prefix_head_stride,
|
||||
const uint output_head_stride) {
|
||||
const uint output_head_stride, const uint prefix_num_tokens) {
|
||||
using pack_128b_t = uint4;
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
@@ -43,6 +44,22 @@ __global__ void merge_attn_states_kernel(
|
||||
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
|
||||
scalar_t* output_head_ptr = output + dst_head_offset;
|
||||
|
||||
// If token_idx >= prefix_num_tokens, just copy from suffix
|
||||
if (token_idx >= prefix_num_tokens) {
|
||||
if (pack_offset < head_size) {
|
||||
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
suffix_head_ptr)[pack_offset / pack_size];
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
s_out_pack;
|
||||
}
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
output_lse[head_idx * num_tokens + token_idx] = s_lse;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// For tokens within prefix range, merge prefix and suffix
|
||||
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
|
||||
@@ -143,7 +160,8 @@ __global__ void merge_attn_states_kernel(
|
||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
||||
num_heads, head_size, prefix_head_stride, output_head_stride); \
|
||||
num_heads, head_size, prefix_head_stride, output_head_stride, \
|
||||
prefix_num_tokens); \
|
||||
}
|
||||
|
||||
/*@brief Merges the attention states from prefix and suffix
|
||||
@@ -157,14 +175,18 @@ __global__ void merge_attn_states_kernel(
|
||||
* @param suffix_output [n,h,d] The suffix attention states.
|
||||
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
|
||||
* states.
|
||||
* @param prefill_tokens_with_context Number of leading prefill tokens (p) that have prefix context; when unset, all n tokens are merged.
|
||||
* For the first p tokens (0 <= token_idx < prefill_tokens_with_context), output
|
||||
* is computed by merging prefix_output and suffix_output. For remaining tokens
|
||||
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
|
||||
* from suffix_output.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void merge_attn_states_launcher(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
void merge_attn_states_launcher(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context) {
|
||||
constexpr uint NUM_THREADS = 128;
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
@@ -174,6 +196,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
|
||||
const uint prefix_num_tokens =
|
||||
prefill_tokens_with_context.has_value()
|
||||
? static_cast<uint>(prefill_tokens_with_context.value())
|
||||
: num_tokens;
|
||||
TORCH_CHECK(prefix_num_tokens <= num_tokens,
|
||||
"prefix_num_tokens must be <= num_tokens");
|
||||
|
||||
float* output_lse_ptr = nullptr;
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
@@ -192,18 +222,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||
}
|
||||
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ \
|
||||
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
|
||||
prefix_lse, suffix_output, \
|
||||
suffix_lse); \
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ \
|
||||
merge_attn_states_launcher<scalar_t>( \
|
||||
output, output_lse, prefix_output, prefix_lse, suffix_output, \
|
||||
suffix_lse, prefill_tokens_with_context); \
|
||||
}
|
||||
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
void merge_attn_states(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
std::optional<int64_t> prefill_tokens_with_context = std::nullopt) {
|
||||
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
}
|
||||
|
||||
11
csrc/ops.h
11
csrc/ops.h
@@ -53,12 +53,11 @@ void paged_attention_v2(
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse);
|
||||
void merge_attn_states(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context);
|
||||
#ifndef USE_ROCM
|
||||
void convert_vertical_slash_indexes(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
|
||||
@@ -73,7 +73,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor prefix_output,"
|
||||
" Tensor prefix_lse,"
|
||||
" Tensor suffix_output,"
|
||||
" Tensor suffix_lse) -> ()");
|
||||
" Tensor suffix_lse,"
|
||||
" int!? prefill_tokens_with_context) -> ()");
|
||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
|
||||
@@ -20,7 +20,11 @@ def merge_attn_states_torch(
|
||||
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
|
||||
output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefill_tokens_with_context: int | None = None,
|
||||
):
|
||||
# Default: when prefill_tokens_with_context is not given, every token is merged with its prefix
|
||||
if prefill_tokens_with_context is None:
|
||||
prefill_tokens_with_context = output.shape[0]
|
||||
p_lse = prefix_lse
|
||||
s_lse = suffix_lse
|
||||
# inf -> -inf
|
||||
@@ -28,6 +32,9 @@ def merge_attn_states_torch(
|
||||
s_lse[s_lse == torch.inf] = -torch.inf
|
||||
# max_lse [NUM_HEADS, NUM_TOKENS]
|
||||
max_lse = torch.maximum(p_lse, s_lse)
|
||||
|
||||
mask = torch.ones((prefix_lse.shape[1], 1, 1), device=p_lse.device)
|
||||
mask[prefill_tokens_with_context:].fill_(0)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
p_lse_exp = torch.exp(p_lse)
|
||||
@@ -35,11 +42,16 @@ def merge_attn_states_torch(
|
||||
out_se = p_lse_exp + s_lse_exp
|
||||
if output_lse is not None:
|
||||
output_lse = torch.log(out_se) + max_lse
|
||||
output_lse[prefill_tokens_with_context:] = suffix_lse[
|
||||
prefill_tokens_with_context:
|
||||
]
|
||||
p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
|
||||
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
|
||||
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
output = prefix_output * p_scale + suffix_output * s_scale
|
||||
output.copy_(
|
||||
prefix_output * p_scale * mask + suffix_output * (s_scale * mask + (1 - mask))
|
||||
)
|
||||
return output, output_lse
|
||||
|
||||
|
||||
@@ -90,13 +102,18 @@ def generate_markdown_table():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prefill_tokens_with_context", [None, 128])
|
||||
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
|
||||
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("output_dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_merge_attn_states(
|
||||
num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
|
||||
prefill_tokens_with_context: int | None,
|
||||
num_tokens: int,
|
||||
num_query_heads: int,
|
||||
head_size: int,
|
||||
output_dtype: torch.dtype,
|
||||
):
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(
|
||||
@@ -111,6 +128,7 @@ def test_merge_attn_states(
|
||||
print(
|
||||
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
|
||||
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
|
||||
f"prefill_tokens_with_context: {prefill_tokens_with_context}, "
|
||||
f"Device: {current_platform.get_device_name()}"
|
||||
)
|
||||
|
||||
@@ -164,6 +182,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -176,6 +195,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
@@ -199,6 +219,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -211,6 +232,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
@@ -231,6 +253,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -243,6 +266,7 @@ def test_merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -264,9 +264,16 @@ def merge_attn_states(
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: torch.Tensor | None = None,
|
||||
prefill_tokens_with_context: int | None = None,
|
||||
) -> None:
|
||||
torch.ops._C.merge_attn_states(
|
||||
output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse
|
||||
output,
|
||||
output_lse,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1181,6 +1181,7 @@ class MLACommonPrefillMetadata:
|
||||
padded_local_cu_seq_lens: torch.Tensor | None = None
|
||||
cu_seq_lens_lst: list[list[int]] | None = None
|
||||
chunk_size: int | None = None
|
||||
prefill_tokens_with_context: int | None = None
|
||||
|
||||
block_table: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
@@ -1743,6 +1744,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
prefill_query_start_loc = (
|
||||
query_start_loc[reqs_start:] - query_start_loc[reqs_start]
|
||||
)
|
||||
prefill_query_start_loc_cpu = (
|
||||
query_start_loc_cpu[reqs_start:] - query_start_loc_cpu[reqs_start]
|
||||
)
|
||||
|
||||
chunked_context_metadata = None
|
||||
if max_context_len_cpu > 0:
|
||||
@@ -1864,6 +1868,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
if self._use_cudnn_prefill
|
||||
else MLACommonPrefillMetadata.ChunkedContextMetadata
|
||||
)
|
||||
prefill_tokens_with_context = None
|
||||
if num_prefills_with_context_cpu > 0:
|
||||
prefill_tokens_with_context = prefill_query_start_loc_cpu[
|
||||
num_prefills_with_context_cpu
|
||||
].item()
|
||||
if self.dcp_world_size > 1:
|
||||
chunked_context_metadata = chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
@@ -1883,6 +1892,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
),
|
||||
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
|
||||
chunk_size=padded_local_max_context_chunk_across_ranks,
|
||||
prefill_tokens_with_context=prefill_tokens_with_context,
|
||||
)
|
||||
else:
|
||||
chunked_context_metadata = chunked_context_metadata_cls(
|
||||
@@ -1896,6 +1906,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
),
|
||||
chunk_total_token=chunk_total_token,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
prefill_tokens_with_context=prefill_tokens_with_context,
|
||||
)
|
||||
|
||||
if self._use_cudnn_prefill:
|
||||
@@ -2382,14 +2393,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
||||
assert prefill.workspace_buffer is not None
|
||||
|
||||
out = torch.zeros(
|
||||
out = torch.empty(
|
||||
q.shape[0],
|
||||
q.shape[1],
|
||||
v.shape[2],
|
||||
device=q.device,
|
||||
dtype=prefill.output_dtype,
|
||||
)
|
||||
prefill.workspace_buffer.fill_(0)
|
||||
|
||||
attn_out, lse = trtllm_ragged_attention_deepseek(
|
||||
query=q,
|
||||
@@ -2691,6 +2701,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
)
|
||||
|
||||
if has_context:
|
||||
assert prefill_metadata.chunked_context is not None
|
||||
suffix_output, suffix_lse = output_prefill
|
||||
if self.dcp_world_size > 1:
|
||||
context_output, context_lse = (
|
||||
@@ -2719,6 +2730,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
prefix_lse=context_lse,
|
||||
suffix_output=suffix_output,
|
||||
suffix_lse=suffix_lse,
|
||||
prefill_tokens_with_context=prefill_metadata.chunked_context.prefill_tokens_with_context,
|
||||
)
|
||||
else:
|
||||
output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
|
||||
|
||||
@@ -13,7 +13,36 @@ def merge_attn_states(
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: torch.Tensor | None = None,
|
||||
prefill_tokens_with_context: int | None = None,
|
||||
) -> None:
|
||||
"""Merge partial attention outputs from prefix (KV cache) and suffix
|
||||
(new tokens) into a single output tensor using the log-sum-exp (LSE)
|
||||
rescaling method described in section 2.2 of
|
||||
https://www.arxiv.org/pdf/2501.01005.
|
||||
|
||||
For tokens that have prefix context (token index < prefill_tokens_with_context),
|
||||
the prefix and suffix partial outputs are combined as a weighted sum.
|
||||
For tokens without prefix context, the suffix output is copied directly.
|
||||
|
||||
Args:
|
||||
output: Output tensor of shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
|
||||
prefix_output: Partial attention output over the prefix (KV cache),
|
||||
shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
|
||||
prefix_lse: Log-sum-exp values for the prefix attention,
|
||||
shape [NUM_HEADS, NUM_TOKENS].
|
||||
suffix_output: Partial attention output over the suffix (new KV),
|
||||
shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
|
||||
suffix_lse: Log-sum-exp values for the suffix attention,
|
||||
shape [NUM_HEADS, NUM_TOKENS].
|
||||
output_lse: Optional tensor to store the merged LSE values,
|
||||
shape [NUM_HEADS, NUM_TOKENS]. If None, LSE is not written out.
|
||||
prefill_tokens_with_context: Number of prefill tokens that have
|
||||
prefix context and therefore require merging. Tokens at indices
|
||||
>= this value are decode or context-free prefill tokens whose
|
||||
output is taken directly from suffix_output. If None, all tokens
|
||||
are treated as having context.
|
||||
"""
|
||||
|
||||
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
|
||||
# does not support FP8 dtype, so fall back to the Triton kernel.
|
||||
def supported_dtypes(o: torch.Tensor) -> bool:
|
||||
@@ -37,11 +66,23 @@ def merge_attn_states(
|
||||
from vllm._custom_ops import merge_attn_states
|
||||
|
||||
return merge_attn_states(
|
||||
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
|
||||
output,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
else:
|
||||
from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
|
||||
return merge_attn_states(
|
||||
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
|
||||
output,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ def merge_attn_states(
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: torch.Tensor | None = None,
|
||||
prefill_tokens_with_context: int | None = None,
|
||||
) -> None:
|
||||
num_tokens = output.shape[0]
|
||||
num_query_heads = output.shape[1]
|
||||
@@ -25,6 +26,11 @@ def merge_attn_states(
|
||||
# backend.
|
||||
prefix_head_stride = prefix_output.stride(1)
|
||||
output_head_stride = output.stride(1)
|
||||
|
||||
# If prefill_tokens_with_context is None, all tokens should use prefix context
|
||||
if prefill_tokens_with_context is None:
|
||||
prefill_tokens_with_context = num_tokens
|
||||
|
||||
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
|
||||
merge_attn_states_kernel[(num_tokens, num_query_heads)](
|
||||
output,
|
||||
@@ -38,6 +44,7 @@ def merge_attn_states(
|
||||
head_size,
|
||||
padded_head_size,
|
||||
output_lse is not None,
|
||||
prefill_tokens_with_context,
|
||||
)
|
||||
|
||||
|
||||
@@ -54,12 +61,44 @@ def merge_attn_states_kernel(
|
||||
HEAD_SIZE: tl.constexpr,
|
||||
PADDED_HEAD_SIZE: tl.constexpr,
|
||||
OUTPUT_LSE: tl.constexpr,
|
||||
prefill_tokens_with_context: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
num_tokens = tl.num_programs(0)
|
||||
head_idx = tl.program_id(1)
|
||||
num_heads = tl.num_programs(1)
|
||||
|
||||
prefix_mask = token_idx < prefill_tokens_with_context
|
||||
|
||||
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
|
||||
head_mask = head_arange < HEAD_SIZE
|
||||
|
||||
# For tokens without context (token_idx >= prefill_tokens_with_context),
|
||||
# directly copy from suffix_output
|
||||
if not prefix_mask:
|
||||
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
|
||||
if OUTPUT_LSE:
|
||||
tl.store(output_lse + head_idx * num_tokens + token_idx, s_lse)
|
||||
|
||||
s_out = tl.load(
|
||||
suffix_output
|
||||
+ token_idx * num_heads * prefix_head_stride
|
||||
+ head_idx * prefix_head_stride
|
||||
+ head_arange,
|
||||
mask=head_mask,
|
||||
)
|
||||
tl.store(
|
||||
output
|
||||
+ token_idx * num_heads * output_head_stride
|
||||
+ head_idx * output_head_stride
|
||||
+ head_arange,
|
||||
s_out,
|
||||
mask=head_mask,
|
||||
)
|
||||
return
|
||||
|
||||
# For tokens with context (token_idx < prefill_tokens_with_context),
|
||||
# perform normal merge operation
|
||||
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
|
||||
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
|
||||
|
||||
@@ -83,8 +122,6 @@ def merge_attn_states_kernel(
|
||||
out_lse = tl.log(out_se) + max_lse
|
||||
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)
|
||||
|
||||
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
|
||||
head_mask = head_arange < HEAD_SIZE
|
||||
p_out = tl.load(
|
||||
prefix_output
|
||||
+ token_idx * num_heads * prefix_head_stride
|
||||
|
||||
Reference in New Issue
Block a user