Files
vllm/csrc/cache.h

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

94 lines
4.3 KiB
C
Raw Normal View History

2023-12-14 12:35:58 -05:00
#pragma once
#include <torch/all.h>
#include <c10/util/Optional.h>
2023-02-16 07:47:03 +00:00
#include <map>
#include <vector>
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);
2023-02-16 07:47:03 +00:00
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
2023-02-18 19:22:57 +00:00
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& scale);
// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla
void concat_and_cache_mla_rope_fused(
torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe,
torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox,
torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache,
const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale);
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt);
// TODO(hc): cp_gather_cache need support scaled kvcahe in the future.
void cp_gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
// Gather and upconvert FP8 KV cache to BF16 workspace
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size);
// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);
// Concatenate query nope and rope for MLA/DSA attention
void concat_mla_q(
torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::Tensor& q_out); // [num_tokens, num_heads, nope_dim + rope_dim]
// Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens); // [batch_size + 1]