diff --git a/csrc/cache.h b/csrc/cache.h
index b162a4a2b..f2a5ec0ac 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -41,11 +41,12 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);
 
 void gather_and_maybe_dequant_cache(
-    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
-    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
-    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
-    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& src_cache,     // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+    torch::Tensor const& dst,           // [TOT_TOKENS, ENTRIES...]
+    torch::Tensor const& block_table,   // [BATCH, BLOCK_INDICES]
+    torch::Tensor const& cu_seq_lens,   // [BATCH+1]
+    torch::Tensor const& token_to_seq,  // [MAX_TOKEN_ACROSS_CHUNKS]
+    int64_t num_tokens, const std::string& kv_cache_dtype,
     torch::Tensor const& scale,
     std::optional<torch::Tensor> seq_starts = std::nullopt);
 
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 32960cc80..8a5457206 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -905,91 +905,79 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
 
 namespace vllm {
 
-// grid is launched with dimensions (batch, num_splits)
-template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+// grid is launched with dimensions (num_tokens): one thread block per token
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt,
+          int ENTRY_SIZE, int CTA_SIZE>
 __global__ void gather_and_maybe_dequant_cache(
-    const cache_t* __restrict__ src_cache,    // [NUM_BLOCKS, BLOCK_SIZE,
-                                              // ENTRIES...]
-    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
-    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
-    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
-    const int32_t block_size, const int32_t entry_size,
+    const cache_t* __restrict__ src_cache,     // [NUM_BLOCKS, BLOCK_SIZE,
+                                               // ENTRIES...]
+    scalar_t* __restrict__ dst,                // [TOT_TOKENS, ENTRIES...]
+    const int32_t* __restrict__ block_table,   // [BATCH, BLOCK_INDICES]
+    const int32_t* __restrict__ cu_seq_lens,   // [BATCH+1]
+    const int32_t* __restrict__ token_to_seq,  // [MAX_TOKEN_ACROSS_CHUNKS]
+    const int32_t num_tokens, const int32_t block_size,
     const int64_t block_table_stride, const int64_t cache_block_stride,
     const int64_t cache_entry_stride, const int64_t dst_entry_stride,
    const float* __restrict__ scale,
     const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
                                                // batch
+  constexpr int vec_size = sizeof(float4) / sizeof(scalar_t);
+  using ltype = vllm::vec_n_t<cache_t, vec_size>;
+  using stype = vllm::vec_n_t<scalar_t, vec_size>;
+  // Kept for readability: the assert documents the launch contract and is
+  // compiled out in release builds.
+  assert(CTA_SIZE == blockDim.x);
 
-  const int64_t bid = blockIdx.x;  // Batch ID
-  const int32_t num_splits = gridDim.y;
-  const int32_t split = blockIdx.y;
-  const int32_t seq_start = cu_seq_lens[bid];
-  const int32_t seq_end = cu_seq_lens[bid + 1];
-  const int32_t seq_len = seq_end - seq_start;
-  const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
-  const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
+#pragma unroll
+  for (int token_id = blockIdx.x; token_id < num_tokens;
+       token_id += gridDim.x) {
+    int64_t batch_id = token_to_seq[token_id];
+    int64_t batch_start = cu_seq_lens[batch_id];
+    int64_t batch_end = cu_seq_lens[batch_id + 1];
+    int32_t batch_offset = token_id - batch_start;
 
-  const int32_t split_start = split * split_blocks;
-  const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
+    if (token_id >= batch_end) return;
+    int32_t offset = 0;
+    if (seq_starts != nullptr) {
+      offset = seq_starts[batch_id];
+    }
+    batch_offset += offset;
+    int32_t block_table_id = batch_offset / block_size;
+    int32_t slot_id = batch_offset % block_size;
+    int32_t block_table_offset = batch_id * block_table_stride + block_table_id;
+    int32_t block_id = block_table[block_table_offset];
+    int64_t cache_offset =
+        block_id * cache_block_stride + slot_id * cache_entry_stride;
+    constexpr int32_t vec_iter_cnt = ENTRY_SIZE / vec_size;
+    scalar_t* dst_ = dst + token_id * dst_entry_stride;
+    cache_t* src_ = const_cast<cache_t*>(src_cache) + cache_offset;
 
-  const bool is_active_split = (split_start < tot_blocks);
-  const bool is_last_split = (split_end == tot_blocks);
-
-  if (!is_active_split) return;
-
-  int32_t full_blocks_end = split_end;
-  int32_t partial_block_size = 0;
-
-  // Adjust the pointer for the block_table for this batch.
-  // If seq_starts is provided, compute an offset based on (seq_starts[bid] /
-  // page_size)
-  const int32_t batch_offset = bid * block_table_stride;
-  int32_t offset = 0;
-  if (seq_starts != nullptr) {
-    offset = seq_starts[bid] / block_size;
-  }
-  const int32_t* batch_block_table = block_table + batch_offset + offset;
-
-  // Adjust dst pointer based on the cumulative sequence lengths.
-  dst += seq_start * dst_entry_stride;
-
-  if (is_last_split) {
-    partial_block_size = seq_len % block_size;
-    if (partial_block_size) full_blocks_end -= 1;
-  }
-
-  auto copy_entry = [&](const cache_t* __restrict__ _src,
-                        scalar_t* __restrict__ _dst) {
-    for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+#pragma unroll
+    for (int idx = threadIdx.x; idx < vec_iter_cnt; idx += CTA_SIZE) {
       if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-        _dst[i] = static_cast<scalar_t>(_src[i]);
+        reinterpret_cast<stype*>(dst_)[idx] =
+            static_cast<stype>(reinterpret_cast<ltype*>(src_)[idx]);
       } else {
-        _dst[i] =
-            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
+        ltype loaded_val = reinterpret_cast<ltype*>(src_)[idx];
+        stype store_val;
+#pragma unroll
+        for (int j = 0; j < vec_size; ++j) {
+          store_val.val[j] = fp8::scaled_convert<scalar_t, cache_t, kv_dt>(
+              loaded_val.val[j], *scale);
+        }
+        reinterpret_cast<stype*>(dst_)[idx] = store_val;
       }
     }
-  };
-
-  const auto loop_end =
-      std::min((int64_t)full_blocks_end, block_table_stride - offset);
-  for (int pid = split_start; pid < loop_end; ++pid) {
-    auto block_id = batch_block_table[pid];
-    auto block_start_ptr = src_cache + block_id * cache_block_stride;
-    auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
-    for (int eid = 0; eid < block_size; ++eid) {
-      copy_entry(block_start_ptr + eid * cache_entry_stride,
-                 block_dst_ptr + eid * dst_entry_stride);
-    }
-  }
-
-  if (partial_block_size) {
-    if (offset + full_blocks_end < block_table_stride) {
-      auto block_id = batch_block_table[full_blocks_end];
-      auto block_start_ptr = src_cache + block_id * cache_block_stride;
-      auto block_dst_ptr =
-          dst + full_blocks_end * block_size * dst_entry_stride;
-      for (int eid = 0; eid < partial_block_size; ++eid) {
-        copy_entry(block_start_ptr + eid * cache_entry_stride,
-                   block_dst_ptr + eid * dst_entry_stride);
+    // process tail
+    constexpr int32_t tail_cnt = ENTRY_SIZE % vec_size;
+    dst_ = dst_ + ENTRY_SIZE - tail_cnt;
+    src_ = src_ + ENTRY_SIZE - tail_cnt;
+#pragma unroll
+    for (int idx = threadIdx.x; idx < tail_cnt; idx += CTA_SIZE) {
+      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+        dst_[idx] = static_cast<scalar_t>(src_[idx]);
+      } else {
+        dst_[idx] =
+            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(src_[idx], *scale);
       }
     }
   }
@@ -1001,34 +989,38 @@ __global__ void gather_and_maybe_dequant_cache(
 // SCALAR_T is the data type of the destination tensor.
 // CACHE_T is the stored data type of kv-cache.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                       \
-  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE>         \
-      <<<grid, block, 0, stream>>>(                                         \
-          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                 \
-          reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                      \
-          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
-          block_size, entry_size, block_table_stride, cache_block_stride,   \
-          cache_entry_stride, dst_entry_stride,                             \
-          reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                         \
+  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, 576,       \
+                                       thread_block_size>                      \
+      <<<grid, block, 0, stream>>>(                                            \
+          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                    \
+          reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                         \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),    \
+          token_to_seq.data_ptr<int32_t>(), num_tokens, block_size,            \
+          block_table_stride, cache_block_stride, cache_entry_stride,          \
+          dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()),  \
+          seq_starts_ptr);
 
 // Gather sequences from the cache into the destination tensor.
 // - cu_seq_lens contains the cumulative sequence lengths for each batch
 // - block_table contains the cache block indices for each sequence
-// - Optionally, seq_starts (if provided) offsets the starting block index by
-//   (seq_starts[bid] / page_size)
+// - token_to_seq contains the mapping from each token_id back to its batch_id
+// - Optionally, seq_starts (if provided) offsets the gather start of each
+//   sequence by seq_starts[bid] tokens
 void gather_and_maybe_dequant_cache(
-    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
-    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
-    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
-    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& src_cache,     // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+    torch::Tensor const& dst,           // [TOT_TOKENS, ENTRIES...]
+    torch::Tensor const& block_table,   // [BATCH, BLOCK_INDICES]
+    torch::Tensor const& cu_seq_lens,   // [BATCH+1]
+    torch::Tensor const& token_to_seq,  // [MAX_TOKEN_ACROSS_CHUNKS]
+    int64_t num_tokens, const std::string& kv_cache_dtype,
     torch::Tensor const& scale,
     std::optional<torch::Tensor> seq_starts = std::nullopt) {
   at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   int32_t block_size = src_cache.size(1);
-  int32_t entry_size = src_cache.flatten(2, -1).size(2);
+  int32_t head_dim = dst.size(-1);
 
   TORCH_CHECK(block_table.dtype() == torch::kInt32,
               "block_table must be int32");
@@ -1038,6 +1030,9 @@ void gather_and_maybe_dequant_cache(
     TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
                 "seq_starts must be int32");
   }
+  TORCH_CHECK(head_dim == 576,
+              "gather_and_maybe_dequant_cache only supports head_dim == 576 "
+              "for better performance");
 
   TORCH_CHECK(src_cache.device() == dst.device(),
               "src_cache and dst must be on the same device");
@@ -1055,10 +1050,9 @@ void gather_and_maybe_dequant_cache(
   int64_t cache_entry_stride = src_cache.stride(1);
   int64_t dst_entry_stride = dst.stride(0);
 
-  // Decide on the number of splits based on the batch size.
-  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
-  dim3 grid(batch_size, num_splits);
-  dim3 block(1024);
+  constexpr int32_t thread_block_size = 64;
+  dim3 grid(num_tokens);
+  dim3 block(thread_block_size);
 
   const int32_t* seq_starts_ptr =
       seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
 
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 5af74c2c2..14913bef1 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -695,7 +695,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
   cache_ops.def(
       "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
       "                               Tensor block_table, Tensor cu_seq_lens, "
-      "                               int batch_size, "
+      "                               Tensor token_to_seq, "
+      "                               int num_tokens, "
       "                               str kv_cache_dtype, "
       "                               Tensor scale, Tensor? seq_starts) -> ()");
   cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
 
diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py
index 028e164cb..acf46d75d 100644
--- a/tests/kernels/attention/test_cache.py
+++ b/tests/kernels/attention/test_cache.py
@@ -921,12 +921,16 @@ def test_gather_and_maybe_dequant_cache_mla(
     )
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
 
-    seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
+    seq_len_tensor = torch.randint(
+        max_seq_len, max_seq_len + 1, (batch_size,), device=device
+    )
 
     total_tokens = seq_len_tensor.sum()
     cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
     cu_seq_lens[0] = 0
     cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
+    token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
+    token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)
     print("seq_len_tensor", seq_len_tensor)
     tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
 
@@ -977,7 +981,8 @@ def test_gather_and_maybe_dequant_cache_mla(
         dst,
         block_table,
         cu_seq_lens,
-        batch_size,
+        token_to_seq,
+        total_tokens,
         kv_cache_dtype,
         scale,
         None,
@@ -990,7 +995,8 @@ def test_gather_and_maybe_dequant_cache_mla(
         dst,
         block_table,
         cu_seq_lens,
-        batch_size,
+        token_to_seq,
+        total_tokens,
         kv_cache_dtype,
         scale,
         None,
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 0f625a794..4a1bcc761 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2201,7 +2201,8 @@ def gather_and_maybe_dequant_cache(
     dst: torch.Tensor,
     block_table: torch.Tensor,
     cu_seq_lens: torch.Tensor,
-    batch_size: int,
+    token_to_seq: torch.Tensor,
+    num_tokens: int,
     kv_cache_dtype: str,
     scale: torch.Tensor,
     seq_starts: torch.Tensor | None = None,
@@ -2211,7 +2212,8 @@ def gather_and_maybe_dequant_cache(
         dst,
         block_table,
         cu_seq_lens,
-        batch_size,
+        token_to_seq,
+        num_tokens,
         kv_cache_dtype,
         scale,
         seq_starts,
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 43aef8a7c..87a3aac21 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -340,6 +340,8 @@ class MLACommonPrefillMetadata:
         max_seq_lens: list[int]
         seq_lens: torch.Tensor
         workspace: torch.Tensor
+        token_to_seq: torch.Tensor
+        chunk_total_token: list[int]
 
         # for mla DCP
         padded_local_chunk_seq_lens: list[list[int]] | None = None
@@ -839,6 +841,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             torch.cumsum(
                 chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
             )
+            chunk_total_token = cu_seq_lens_cpu[:, -1]
+
+            max_token_num_over_chunk = chunk_total_token.max().item()
+            token_to_seq_tensor_cpu = torch.zeros(
+                [num_chunks, max_token_num_over_chunk], dtype=torch.int32
+            )
+            range_idx = torch.arange(num_prefills, dtype=torch.int32)
+            for i in range(num_chunks):
+                chunk_token_to_seq_tensor = torch.repeat_interleave(
+                    range_idx, chunk_seq_lens[i]
+                )
+                chunk_len = chunk_token_to_seq_tensor.shape[0]
+                token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor
 
             if self.dcp_world_size > 1:
                 local_context_lens_allranks = get_dcp_local_seq_lens(
@@ -906,6 +921,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                     seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     seq_lens=chunk_seq_lens,
+                    token_to_seq=token_to_seq_tensor_cpu.to(
+                        device, non_blocking=True
+                    ),
+                    chunk_total_token=chunk_total_token.tolist(),
                     workspace=self.chunked_prefill_workspace,
                     padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
                     local_context_lens_allranks=local_context_lens_allranks.tolist(),
@@ -922,6 +941,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     seq_lens=chunk_seq_lens,
+                    token_to_seq=token_to_seq_tensor_cpu.to(
+                        device, non_blocking=True
+                    ),
+                    chunk_total_token=chunk_total_token.tolist(),
                     workspace=self.chunked_prefill_workspace,
                 )
 
@@ -1638,16 +1661,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             output = None
             iters = len(prefill_metadata.chunked_context.seq_tot)
             workspace = prefill_metadata.chunked_context.workspace
-
            for i in range(iters):
                 toks = prefill_metadata.chunked_context.seq_tot[i]
-
                 ops.gather_and_maybe_dequant_cache(
                     src_cache=kv_c_and_k_pe_cache,
                     dst=workspace,
                     block_table=prefill_metadata.block_table,
                     cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
-                    batch_size=attn_metadata.num_prefills,
+                    token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
+                    num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
                     kv_cache_dtype=self.kv_cache_dtype,
                     scale=k_scale,
                     seq_starts=prefill_metadata.chunked_context.starts[i],
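
Usage sketch (illustrative, not part of the patch above): the snippet below shows how the new token_to_seq argument relates to cu_seq_lens and how the reworked op is then invoked. It mirrors the repeat_interleave construction used in test_gather_and_maybe_dequant_cache_mla and in the MLA metadata builder; the concrete lengths and the tensors named in the commented-out call (kv_cache, block_table, workspace, scale) are made-up placeholders, not values from this diff.

import torch

# Two sequences with 3 and 5 cached context tokens respectively.
seq_lens = torch.tensor([3, 5], dtype=torch.int32)

# Cumulative lengths, as expected by cu_seq_lens: [0, 3, 8].
cu_seq_lens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
cu_seq_lens[1:] = seq_lens.cumsum(dim=0)

# Back-mapping from each flattened token index to its sequence index:
# [0, 0, 0, 1, 1, 1, 1, 1]. The kernel uses this to look up cu_seq_lens[batch_id]
# and the matching row of block_table for every gathered token.
token_to_seq = torch.repeat_interleave(
    torch.arange(seq_lens.numel(), dtype=torch.int32), seq_lens
)
num_tokens = int(cu_seq_lens[-1])  # 8: one thread block per gathered token

# The updated op would then be called roughly like this (device tensors omitted):
# ops.gather_and_maybe_dequant_cache(
#     src_cache=kv_cache, dst=workspace, block_table=block_table,
#     cu_seq_lens=cu_seq_lens.cuda(), token_to_seq=token_to_seq.cuda(),
#     num_tokens=num_tokens, kv_cache_dtype="auto", scale=scale,
#     seq_starts=None,
# )

Compared with the previous (batch_size, num_splits) launch, this layout assigns work per token rather than per batch entry, so load balance no longer depends on how evenly the split heuristic divides each sequence's blocks.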