diff --git a/CMakeLists.txt b/CMakeLists.txt
index 168376ca1..c9b1bf54e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -293,6 +293,7 @@ set(VLLM_EXT_SRC
     "csrc/fused_qknorm_rope_kernel.cu"
     "csrc/layernorm_quant_kernels.cu"
     "csrc/sampler.cu"
+    "csrc/topk.cu"
     "csrc/cuda_view.cu"
     "csrc/quantization/gptq/q_gemm.cu"
     "csrc/quantization/w8a8/int8/scaled_quant.cu"
diff --git a/csrc/ops.h b/csrc/ops.h
index 9ee6bda31..f5dfb0ecc 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -114,6 +114,10 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
                           int64_t numRows, int64_t stride0, int64_t stride1,
                           int64_t topK);
 
+void large_context_topk(const torch::Tensor& score, torch::Tensor& indices,
+                        const torch::Tensor& lengths,
+                        std::optional<torch::Tensor> row_starts_opt);
+
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
                                double epsilon);
diff --git a/csrc/sampler.cu b/csrc/sampler.cu
index f7c091f1d..30bfef33c 100644
--- a/csrc/sampler.cu
+++ b/csrc/sampler.cu
@@ -725,4 +725,4 @@ void top_k_per_row_prefill(const torch::Tensor& logits,
       static_cast<int>(stride0), static_cast<int>(stride1),
       static_cast<int>(topK), kSortingAlgorithmThreshold);
 }
-}
+}
\ No newline at end of file
diff --git a/csrc/topk.cu b/csrc/topk.cu
new file mode 100644
index 000000000..e2702b2d0
--- /dev/null
+++ b/csrc/topk.cu
@@ -0,0 +1,373 @@
+// Portions of this file are adapted from SGLang PRs:
+// https://github.com/sgl-project/sglang/pull/11194
+// and
+// https://github.com/sgl-project/sglang/pull/17747
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#ifndef USE_ROCM
+  #include <cuda_fp16.h>
+#else
+  #include <hip/hip_fp16.h>
+#endif
+
+namespace vllm {
+
+constexpr int TopK = 2048;              // DeepSeek V3 sparse attention top-k
+constexpr int kThreadsPerBlock = 1024;  // Threads per block
+
+// Shared memory budget
+#if defined(USE_ROCM)
+constexpr size_t kSmem = 48 * 1024;  // ROCm default: 48KB
+#else
+// Reduced from 128KB to 32KB to improve occupancy.
+// Each radix pass needs at most ~TopK candidates in the threshold bin,
+// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient.
+constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t);  // 32KB (bytes)
+#endif
+
+struct FastTopKParams {
+  const float* __restrict__ input;         // [batch, seq_len] Logits
+  const int32_t* __restrict__ row_starts;  // [batch] Offset into each row
+                                           // (optional)
+  int32_t* __restrict__ indices;           // [batch, TopK] Output top-k indices
+  int32_t* __restrict__ lengths;           // [batch] Sequence lengths per row
+  int64_t input_stride;                    // Stride between rows
+};
+
+__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t {
+  uint32_t bits = __float_as_uint(x);
+  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
+}
+
+__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
+  __half h = __float2half_rn(x);
+  uint16_t bits = __half_as_ushort(h);
+  uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
+                                 : static_cast<uint16_t>(bits | 0x8000);
+  return static_cast<uint8_t>(key >> 8);
+}
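+
+// Illustrative example of the order-preserving key mapping above
+// (documentation only, not part of the algorithm):
+//   key( 1.0f) = 0x3F800000 | 0x80000000 = 0xBF800000
+//   key( 0.0f) = 0x00000000 | 0x80000000 = 0x80000000
+//   key(-1.0f) = ~0xBF800000             = 0x407FFFFF
+// so key(-1.0f) < key(0.0f) < key(1.0f): unsigned comparison of keys matches
+// float ordering, letting the radix passes treat logits as plain integers.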
+
+__device__ void naive_topk_cuda(const float* __restrict__ logits,
+                                int32_t* __restrict__ output_indices,
+                                int32_t seq_len) {
+  const int thread_id = threadIdx.x;
+  for (int i = thread_id; i < TopK; i += kThreadsPerBlock) {
+    output_indices[i] = (i < seq_len) ? i : -1;
+  }
+}
+
+// Adapted from:
+// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87
+// by: DarkSharpness
+// which is itself an optimized top-k kernel ported from a tilelang kernel
+__device__ void fast_topk_cuda_tl(
+    const float* __restrict__ logits,  // Input logits [seq_len]
+    int* __restrict__ output_indices,  // Output top-k indices [TopK]
+    int logits_offset,                 // Starting offset in logits array
+    int seq_len)                       // Number of valid logits to process
+{
+  constexpr int RADIX = 256;
+  constexpr int MAX_BUFFERED_ITEMS = kSmem / (2 * sizeof(int));
+
+  alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
+  alignas(128) __shared__ int shared_output_count;
+  alignas(128) __shared__ int shared_threshold_bin;
+  alignas(128) __shared__ int shared_buffered_count[2];
+
+  extern __shared__ int buffered_indices[][MAX_BUFFERED_ITEMS];
+
+  const int thread_id = threadIdx.x;
+  int remaining_k = TopK;
+
+  // Pass 0: Build coarse 8-bit histogram using FP16 high bits
+  if (thread_id < RADIX + 1) {
+    shared_histogram[0][thread_id] = 0;
+  }
+  __syncthreads();
+
+  for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
+    const auto bin = convert_to_uint8(logits[idx + logits_offset]);
+    ::atomicAdd(&shared_histogram[0][bin], 1);
+  }
+  __syncthreads();
+
+  // Helper: Compute cumulative sum (suffix sum) over histogram using
+  // ping-pong buffers
+  auto compute_cumulative_sum = [&]() {
+    static_assert(1 << 8 == RADIX,
+                  "Radix must be 256 for 8 unrolled iterations");
+#pragma unroll 8
+    for (int i = 0; i < 8; ++i) {
+      if (C10_LIKELY(thread_id < RADIX)) {
+        const int stride = 1 << i;
+        const int src_buffer = i & 1;
+        const int dst_buffer = src_buffer ^ 1;
+
+        int value = shared_histogram[src_buffer][thread_id];
+        if (thread_id < RADIX - stride) {
+          value += shared_histogram[src_buffer][thread_id + stride];
+        }
+        shared_histogram[dst_buffer][thread_id] = value;
+      }
+      __syncthreads();
+    }
+  };
+
+  compute_cumulative_sum();
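+
+  // Worked example (documentation only): shared_histogram[0][b] now holds the
+  // suffix sum S[b] = number of elements with bin >= b. With remaining_k = 4
+  // and, say, S[253] = 6, S[254] = 3, bin 253 is the threshold: the 3 elements
+  // in bins > 253 are definitely in the top-k, and bin 253 itself must be
+  // refined to pick the last 4 - 3 = 1 winner.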
+
+  // Find threshold bin where cumsum crosses remaining_k
+  if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
+      shared_histogram[0][thread_id + 1] <= remaining_k) {
+    shared_threshold_bin = thread_id;
+    shared_buffered_count[0] = 0;
+    shared_output_count = 0;
+  }
+  __syncthreads();
+
+  const int threshold_bin = shared_threshold_bin;
+  remaining_k -= shared_histogram[0][threshold_bin + 1];
+
+  // Early exit if threshold bin perfectly matches remaining_k
+  if (remaining_k == 0) {
+    for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
+      const int bin = convert_to_uint8(logits[idx + logits_offset]);
+      if (bin > threshold_bin) {
+        const int output_pos = ::atomicAdd(&shared_output_count, 1);
+        output_indices[output_pos] = idx;
+      }
+    }
+    __syncthreads();
+    return;
+  }
+
+  // Prepare for refinement passes: Process threshold bin
+  __syncthreads();
+  if (thread_id < RADIX + 1) {
+    shared_histogram[0][thread_id] = 0;
+  }
+  __syncthreads();
+
+  // Scan all elements and:
+  // 1. Write indices > threshold_bin to output
+  // 2. Buffer indices == threshold_bin for refinement
+  // 3. Build histogram for next refinement pass (fused optimization)
+  for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
+    const float logit_value = logits[idx + logits_offset];
+    const int bin = convert_to_uint8(logit_value);
+
+    if (bin > threshold_bin) {
+      // in top-k, write to output
+      const int output_pos = ::atomicAdd(&shared_output_count, 1);
+      output_indices[output_pos] = idx;
+    } else if (bin == threshold_bin) {
+      // Candidate for top-k, needs refinement
+      const int buffer_pos = ::atomicAdd(&shared_buffered_count[0], 1);
+      if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
+        buffered_indices[0][buffer_pos] = idx;
+        // Fused: Build histogram for next pass
+        const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
+        const int next_bin = (fp32_bits >> 24) & 0xFF;
+        ::atomicAdd(&shared_histogram[0][next_bin], 1);
+      }
+    }
+  }
+  __syncthreads();
+
+  // ==========================================================================
+  // Passes 1-4: Refine using 8-bit passes over FP32 bits
+  // ==========================================================================
+  // FP32 bits [31:0] split into 4 bytes processed MSB-first:
+  // Pass 1: bits [31:24], Pass 2: bits [23:16], Pass 3: bits [15:8],
+  // Pass 4: bits [7:0]
+#pragma unroll 4
+  for (int pass = 0; pass < 4; ++pass) {
+    __shared__ int shared_final_k;  // For final pass: remaining slots to fill
+    const int src_buffer = pass % 2;
+    const int dst_buffer = src_buffer ^ 1;
+
+    // Clamp buffered count to prevent overflow
+    const int raw_buffered = shared_buffered_count[src_buffer];
+    const int num_buffered =
+        (raw_buffered < MAX_BUFFERED_ITEMS) ? raw_buffered : MAX_BUFFERED_ITEMS;
+
+    compute_cumulative_sum();
+
+    // Find threshold bin for this pass
+    if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
+        shared_histogram[0][thread_id + 1] <= remaining_k) {
+      shared_threshold_bin = thread_id;
+      shared_buffered_count[dst_buffer] = 0;
+      shared_final_k = remaining_k - shared_histogram[0][thread_id + 1];
+    }
+    __syncthreads();
+
+    const int threshold_bin = shared_threshold_bin;
+    remaining_k -= shared_histogram[0][threshold_bin + 1];
+
+    // Bit offset for this pass: 24, 16, 8, 0
+    const int bit_offset = 24 - pass * 8;
+
+    // Early exit if threshold bin perfectly matches
+    if (remaining_k == 0) {
+      for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
+        const int idx = buffered_indices[src_buffer][i];
+        const uint32_t fp32_bits =
+            convert_to_uint32_v2(logits[idx + logits_offset]);
+        const int bin = (fp32_bits >> bit_offset) & 0xFF;
+        if (bin > threshold_bin) {
+          const int output_pos = ::atomicAdd(&shared_output_count, 1);
+          output_indices[output_pos] = idx;
+        }
+      }
+      __syncthreads();
+      break;
+    }
+
+    // Continue refinement
+    __syncthreads();
+    if (thread_id < RADIX + 1) {
+      shared_histogram[0][thread_id] = 0;
+    }
+    __syncthreads();
+
+    for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
+      const int idx = buffered_indices[src_buffer][i];
+      const float logit_value = logits[idx + logits_offset];
+      const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
+      const int bin = (fp32_bits >> bit_offset) & 0xFF;
+
+      if (bin > threshold_bin) {
+        // Definitely in top-k
+        const int output_pos = ::atomicAdd(&shared_output_count, 1);
+        output_indices[output_pos] = idx;
+      } else if (bin == threshold_bin) {
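+        // Ties that reach the last pass agree on all four key bytes, i.e.
+        // they are bit-identical logits, so any shared_final_k of them may
+        // take the remaining output slots (claimed below by atomic
+        // decrement).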
+        if (pass == 3) {
+          // Final pass (bits [7:0]): No more refinement possible
+          // Fill remaining slots in reverse order to maintain descending
+          // order
+          const int slot = ::atomicAdd(&shared_final_k, -1);
+          if (slot > 0) {
+            output_indices[TopK - slot] = idx;
+          }
+        } else {
+          // Buffer for next pass and build next histogram
+          const int buffer_pos =
+              ::atomicAdd(&shared_buffered_count[dst_buffer], 1);
+          if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
+            buffered_indices[dst_buffer][buffer_pos] = idx;
+            // Fused: Build histogram for next pass
+            const int next_bit_offset = bit_offset - 8;
+            const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF;
+            ::atomicAdd(&shared_histogram[0][next_bin], 1);
+          }
+        }
+      }
+    }
+    __syncthreads();
+  }
+}
+
+__global__ __launch_bounds__(kThreadsPerBlock) void topk_kernel(
+    const FastTopKParams params) {
+  const auto& [input, row_starts, indices, lengths, input_stride] = params;
+  const uint64_t batch_idx = blockIdx.x;
+  const int logits_offset = row_starts == nullptr ? 0 : row_starts[batch_idx];
+  const int seq_len = lengths[batch_idx];
+  int* output_indices = indices + batch_idx * TopK;
+  const float* logits = input + batch_idx * input_stride;
+
+  if (seq_len <= TopK) {
+    // Shortcut: All elements are in top-k
+    return naive_topk_cuda(logits, output_indices, seq_len);
+  } else {
+    return fast_topk_cuda_tl(logits, output_indices, logits_offset, seq_len);
+  }
+}
+
+FastTopKParams get_params(
+    const at::Tensor& score, const at::Tensor& lengths,
+    std::optional<at::Tensor> row_starts_opt = std::nullopt,
+    std::optional<at::Tensor> indices_opt = std::nullopt) {
+  const int64_t batch_size = score.size(0);
+
+  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1,
+              "score must be 2D with contiguous rows");
+  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous() &&
+                  lengths.size(0) == batch_size,
+              "lengths must be 1D contiguous with size matching batch");
+
+  const int32_t* row_starts_ptr = nullptr;
+  if (row_starts_opt.has_value()) {
+    const auto& row_starts = *row_starts_opt;
+    TORCH_CHECK(row_starts.dim() == 1 && row_starts.size(0) == batch_size,
+                "row_starts must be 1D with size matching batch");
+    row_starts_ptr = row_starts.data_ptr<int32_t>();
+  }
+
+  int32_t* indices_ptr = nullptr;
+  if (indices_opt.has_value()) {
+    const auto& indices = *indices_opt;
+    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous() &&
+                    indices.size(0) == batch_size && indices.size(1) == TopK,
+                "indices must be 2D contiguous [batch, TopK]");
+    indices_ptr = indices.data_ptr<int32_t>();
+  }
+
+  return FastTopKParams{
+      .input = score.data_ptr<float>(),
+      .row_starts = row_starts_ptr,
+      .indices = indices_ptr,
+      .lengths = lengths.data_ptr<int32_t>(),
+      .input_stride = score.stride(0),
+  };
+}
+
+template <auto kernel_func, size_t smem_bytes>
+void setup_kernel_smem_once() {
+  static const cudaError_t result = []() -> cudaError_t {
+#ifdef USE_ROCM
+    auto func_ptr = reinterpret_cast<const void*>(kernel_func);
+#else
+    auto func_ptr = kernel_func;
+#endif
+    return cudaFuncSetAttribute(
+        func_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+  }();
+
+  TORCH_CHECK(
+      result == cudaSuccess,
+      "Failed to set kernel shared memory limit: ", cudaGetErrorString(result));
+}
+
+}  // namespace vllm
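+
+// Host entry point. Expects logits as float32 [batch, seq_len] with
+// contiguous rows, indices as a pre-allocated int32 [batch, TopK = 2048]
+// output, seq_lens as int32 [batch], and an optional int32 [batch] per-row
+// read offset in row_starts. One 1024-thread block is launched per row.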
+void large_context_topk(
+    const torch::Tensor& logits, torch::Tensor& indices,
+    const torch::Tensor& seq_lens,
+    std::optional<torch::Tensor> row_starts = std::nullopt) {
+  TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
+  TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor");
+  TORCH_CHECK(seq_lens.is_cuda(), "seq_lens must be a CUDA tensor");
+  if (row_starts.has_value()) {
+    TORCH_CHECK(row_starts->is_cuda(), "row_starts must be a CUDA tensor");
+  }
+
+  const auto params = vllm::get_params(logits, seq_lens, row_starts, indices);
+  const int64_t batch_size = logits.size(0);
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const dim3 grid(static_cast<uint32_t>(batch_size));
+  const dim3 block(vllm::kThreadsPerBlock);
+
+  vllm::setup_kernel_smem_once<vllm::topk_kernel, vllm::kSmem>();
+  vllm::topk_kernel<<<grid, block, vllm::kSmem, stream>>>(params);
+
+  const cudaError_t result = cudaGetLastError();
+  TORCH_CHECK(result == cudaSuccess,
+              "large_context_topk kernel failed: ", cudaGetErrorString(result));
+}
\ No newline at end of file
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 97c0e80e7..9766b15ea 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -190,6 +190,12 @@
       "int numRows, int stride0, int stride1, int topK) -> ()");
   ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
 
+  ops.def(
+      "large_context_topk(Tensor score, Tensor! indices, Tensor lengths, "
+      "Tensor? row_starts_opt) -> ()");
+  ops.impl("large_context_topk", torch::kCUDA, &large_context_topk);
+
   // Layernorm-quant
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py
index 2d9dd2a04..9b96e6dfc 100644
--- a/tests/kernels/test_top_k_per_row.py
+++ b/tests/kernels/test_top_k_per_row.py
@@ -275,3 +275,114 @@ def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None:
     _run_top_k_per_row_decode_test(
         top_k, batch_size, next_n, vocab_size, clean_logits, data_generation
     )
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
+@pytest.mark.parametrize("clean_logits", [True, False])
+@torch.inference_mode()
+def test_deepseek_hybrid_topk(clean_logits: bool) -> None:
+    torch.set_default_device("cuda:0")
+
+    top_k = 2048
+
+    # Test case 1: Short sequences (< 8192)
+    batch_size_short = 4
+    next_n = 1
+    num_rows_short = batch_size_short * next_n
+
+    # Create sequences with max length < 8192
+    seq_lens_short = torch.randint(
+        4000, 8000, (batch_size_short,), dtype=torch.int32, device="cuda"
+    )
+
+    row_starts_short = torch.zeros(num_rows_short, dtype=torch.int32, device="cuda")
+    row_indices_short = torch.arange(num_rows_short, device="cuda") // next_n
+    next_n_offset_short = torch.arange(num_rows_short, device="cuda") % next_n
+    row_ends_short = (
+        seq_lens_short[row_indices_short] - next_n + next_n_offset_short + 1
+    )
+
+    logits_short = create_random_logits(
+        row_starts_short, row_ends_short, torch.float32, 42, clean_logits, "random"
+    )
+
+    indices_vllm = torch.empty(
+        (num_rows_short, top_k), dtype=torch.int32, device="cuda"
+    )
+
+    # Use vllm's kernel for short sequences
+    torch.ops._C.top_k_per_row_decode(
+        logits_short,
+        next_n,
+        seq_lens_short,
+        indices_vllm,
+        num_rows_short,
+        logits_short.stride(0),
+        logits_short.stride(1),
+        top_k,
+    )
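+
+    # The 8192 boundary between the two cases mirrors the dispatch heuristic
+    # in DeepseekV32IndexerMetadataBuilder, which sets use_large_context_topk
+    # only for batch_size <= 128 and max_seq_len > 8192, so both kernels are
+    # exercised.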
+
+    # Test case 2: Long sequences (>= 8192) - should use large_context_topk kernel
+    batch_size_long = 4
+    num_rows_long = batch_size_long * next_n
+
+    # Create sequences with max length >= 8192
+    seq_lens_long = torch.randint(
+        8192, 16384, (batch_size_long,), dtype=torch.int32, device="cuda"
+    )
+
+    row_starts_long = torch.zeros(num_rows_long, dtype=torch.int32, device="cuda")
+    row_indices_long = torch.arange(num_rows_long, device="cuda") // next_n
+    next_n_offset_long = torch.arange(num_rows_long, device="cuda") % next_n
+    row_ends_long = seq_lens_long[row_indices_long] - next_n + next_n_offset_long + 1
+
+    logits_long = create_random_logits(
+        row_starts_long, row_ends_long, torch.float32, 43, clean_logits, "random"
+    )
+
+    indices = torch.empty((num_rows_long, top_k), dtype=torch.int32, device="cuda")
+
+    # Use large_context_topk kernel for long sequences
+    if next_n == 1:
+        lengths = seq_lens_long
+    else:
+        offsets = torch.arange(next_n, device=logits_long.device, dtype=torch.int32)
+        lengths = (seq_lens_long.unsqueeze(1) - next_n + 1 + offsets).flatten()
+
+    torch.ops._C.large_context_topk(
+        logits_long,
+        indices,
+        lengths,
+        None,
+    )
+
+    torch_indices_short = torch.empty(
+        (num_rows_short, top_k), dtype=torch.int32, device="cuda"
+    )
+    for i in range(num_rows_short):
+        row_end = int(row_ends_short[i])
+        k_i = min(top_k, row_end)
+        idx = logits_short[i, :row_end].topk(k_i, dim=-1)[1]
+        torch_indices_short[i, :k_i] = idx
+
+    assert compare_top_k_results(
+        logits_short,
+        indices_vllm,
+        torch_indices_short,
+        row_starts_short,
+        row_ends_short,
+        top_k,
+    ), "top_k_per_row_decode kernel (short sequences) doesn't match torch.topk"
+
+    torch_indices_long = torch.empty(
+        (num_rows_long, top_k), dtype=torch.int32, device="cuda"
+    )
+    for i in range(num_rows_long):
+        row_end = int(row_ends_long[i])
+        k_i = min(top_k, row_end)
+        idx = logits_long[i, :row_end].topk(k_i, dim=-1)[1]
+        torch_indices_long[i, :k_i] = idx
+
+    assert compare_top_k_results(
+        logits_long, indices, torch_indices_long, row_starts_long, row_ends_long, top_k
+    ), "large_context_topk kernel (long sequences) doesn't match torch.topk"
diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py
index 9ca7a42b7..bd063de74 100644
--- a/vllm/model_executor/layers/sparse_attn_indexer.py
+++ b/vllm/model_executor/layers/sparse_attn_indexer.py
@@ -126,6 +126,15 @@
                 topk_tokens,
             )
 
+    # Compute lengths from row spans
+    # lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
+    # torch.ops._C.large_context_topk(
+    #     logits,
+    #     topk_indices,
+    #     lengths,
+    #     chunk.cu_seqlen_ks,  # row_starts
+    # )
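+    # NOTE: this chunked-prefill integration is left commented out; only the
+    # decode branch below dispatches to large_context_topk.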
+
     if has_decode:
         decode_metadata = attn_metadata.decode
         # kv_cache size requirement [num_block, block_size, n_head, head_dim],
@@ -162,18 +171,37 @@
             )
 
         num_rows = logits.shape[0]
         topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
-        torch.ops._C.top_k_per_row_decode(
-            logits,
-            next_n,
-            decode_metadata.seq_lens,
-            topk_indices,
-            num_rows,
-            logits.stride(0),
-            logits.stride(1),
-            topk_tokens,
-        )
+
+        if decode_metadata.use_large_context_topk:
+            if next_n == 1:
+                lengths = decode_metadata.seq_lens
+            else:
+                # (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
+                lengths = (
+                    decode_metadata.seq_lens.unsqueeze(1)
+                    - next_n
+                    + 1
+                    + decode_metadata.offsets
+                ).flatten()
+
+            torch.ops._C.large_context_topk(
+                logits,
+                topk_indices,
+                lengths,
+                None,
+            )
+        else:
+            torch.ops._C.top_k_per_row_decode(
+                logits,
+                next_n,
+                decode_metadata.seq_lens,
+                topk_indices,
+                num_rows,
+                logits.stride(0),
+                logits.stride(1),
+                topk_tokens,
+            )
 
         if decode_metadata.requires_padding:
             # if padded, we need to unpack
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 8c1ea1646..368b217f0 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -86,6 +86,8 @@ class DeepSeekV32IndexerDecodeMetadata:
     decode_lens: torch.Tensor
     requires_padding: bool
     schedule_metadata: torch.Tensor
+    use_large_context_topk: bool
+    offsets: torch.Tensor | None  # Precomputed offsets for speculative decoding
 
 
 @dataclass
@@ -320,6 +322,21 @@
         # Use CPU to avoid GPU sync; breaking async scheduling
         requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
 
+        # Decide which top-k kernel to use based on batch size and sequence
+        # length
+        batch_size = num_decodes
+        _is_large_context = common_attn_metadata.max_seq_len > 8192
+
+        # Decision logic based on micro-benchmark results:
+        # - large_context_topk wins for batch <= 128 and seq_len > 8K
+        # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
+        use_large_context_topk = batch_size <= 128 and _is_large_context
+
+        next_n = 1 + self.num_speculative_tokens
+        if next_n > 1:
+            offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
+        else:
+            offsets = None
+
         seq_lens = common_attn_metadata.seq_lens[:num_decodes]
         if is_deep_gemm_supported():
             self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
@@ -331,6 +348,8 @@
             decode_lens=decode_lens,
             requires_padding=requires_padding,
             schedule_metadata=self.scheduler_metadata_buffer,
+            use_large_context_topk=use_large_context_topk,
+            offsets=offsets,
         )
 
         attn_metadata = DeepseekV32IndexerMetadata(
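
Usage sketch for the new op (illustrative only; assumes a CUDA build with this
patch applied and vllm's _C extension loaded, and the shapes and values below
are examples, not part of the diff):

    import torch

    batch, seq_len = 4, 16384
    top_k = 2048  # fixed by the kernel (TopK constant)

    score = torch.randn(batch, seq_len, dtype=torch.float32, device="cuda")
    lengths = torch.full((batch,), seq_len, dtype=torch.int32, device="cuda")
    indices = torch.empty(batch, top_k, dtype=torch.int32, device="cuda")

    # Writes the top-2048 column indices of each row into `indices`.
    # Rows with length <= 2048 take the naive path and are padded with -1.
    torch.ops._C.large_context_topk(score, indices, lengths, None)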