[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
Gregory Shtrasberg
2025-01-23 13:04:03 -05:00
committed by GitHub
parent 6e650f56a1
commit e97f802b2d
60 changed files with 276 additions and 1365 deletions

View File

@@ -218,7 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, float k_scale, float v_scale) {
int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) {
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
@@ -406,7 +406,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
const _B8x8 Vlocalb8 = v_ptrh8be[d];
Vlocal[h][b * BLOCK_SIZE / 8 + d] =
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, *v_scale_ptr);
}
}
}
@@ -416,7 +416,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
#pragma unroll
for (int d = 0; d < KHELOOP; d++) {
Klocal[d] =
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], *k_scale_ptr);
}
}
@@ -890,7 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, float k_scale, float v_scale) {
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
UNREACHABLE_CODE
}
@@ -919,7 +919,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale, v_scale);
k_scale_ptr, v_scale_ptr);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
@@ -929,7 +929,7 @@ void paged_attention_custom_launcher(
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
float k_scale, float v_scale) {
torch::Tensor& k_scale, torch::Tensor& v_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -953,6 +953,8 @@ void paged_attention_custom_launcher(
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
const int max_num_partitions =
@@ -1087,7 +1089,8 @@ void paged_attention(
torch::Tensor& context_lens, // [num_seqs]
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale) {
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
const int head_size = query.size(2);
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Half) {