Support FP8-E5M2 KV Cache (#2279)

Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
zhaoyang-star
2024-01-29 08:43:54 +08:00
committed by GitHub
parent 7d648418b8
commit 9090bf02e7
26 changed files with 912 additions and 196 deletions

View File

@@ -25,6 +25,7 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#include <algorithm>
@@ -79,17 +80,19 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Grid: (num_heads, num_seqs, max_num_partitions).
template<
typename scalar_t,
typename cache_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_E5M2_KV_CACHE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
@@ -145,6 +148,9 @@ __device__ void paged_attention_kernel(
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
@@ -176,7 +182,7 @@ __device__ void paged_attention_kernel(
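// x is the number of cache elements in one 16-byte chunk, i.e. the innermost
// dimension of the key cache layout [..., head_size/x, block_size, x].
// NOTE: x == THREAD_GROUP_SIZE * VEC_SIZE holds only when the cache element
// size equals sizeof(scalar_t), because VEC_SIZE is derived from scalar_t; in
// that case each thread group fetches x elements from the key at a time.
// With a 1-byte FP8 cache element x is larger than THREAD_GROUP_SIZE * VEC_SIZE,
// and offset1/offset2 below map each vector index to (chunk, position in chunk).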
constexpr int x = 16 / sizeof(scalar_t);
constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX;
// Iterate over the key blocks.
@@ -202,13 +208,23 @@ __device__ void paged_attention_kernel(
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride
+ physical_block_offset * x;
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec.
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#else
assert(false);
#endif
} else {
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}
}
// Compute dot product.
@@ -282,6 +298,9 @@ __device__ void paged_attention_kernel(
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
@@ -307,14 +326,25 @@ __device__ void paged_attention_kernel(
L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
V_vec v_vec;
if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#else
assert(false);
#endif
} else {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
}
if (block_idx == num_context_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
@@ -395,14 +425,16 @@ __device__ void paged_attention_kernel(
// Grid: (num_heads, num_seqs, 1).
template<
typename scalar_t,
typename cache_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS>
int NUM_THREADS,
bool IS_FP8_E5M2_KV_CACHE>
__global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
@@ -412,7 +444,7 @@ __global__ void paged_attention_v1_kernel(
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
@@ -421,17 +453,19 @@ __global__ void paged_attention_v1_kernel(
// Grid: (num_heads, num_seqs, max_num_partitions).
template<
typename scalar_t,
typename cache_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_E5M2_KV_CACHE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
@@ -441,7 +475,7 @@ __global__ void paged_attention_v2_kernel(
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride);
@@ -550,10 +584,10 @@ __global__ void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>), \
shared_mem_size); \
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
@@ -571,7 +605,9 @@ __global__ void paged_attention_v2_reduce_kernel(
// TODO(woosuk): Tune NUM_THREADS.
template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_E5M2_KV_CACHE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out,
@@ -602,8 +638,8 @@ void paged_attention_v1_launcher(
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
@@ -647,35 +683,35 @@ void paged_attention_v1_launcher(
}
}
#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_launcher<T, BLOCK_SIZE>( \
out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
context_lens, \
max_context_len, \
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER(T, 8); \
break; \
case 16: \
CALL_V1_LAUNCHER(T, 16); \
break; \
case 32: \
CALL_V1_LAUNCHER(T, 32); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
break; \
case 16: \
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
break; \
case 32: \
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1(
@@ -689,20 +725,36 @@ void paged_attention_v1(
torch::Tensor& context_lens, // [num_seqs]
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8_e5m2") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, \
max_logits_ptr, \
@@ -730,7 +782,9 @@ void paged_attention_v1(
template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_E5M2_KV_CACHE,
int NUM_THREADS = 128,
int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
@@ -768,8 +822,8 @@ void paged_attention_v2_launcher(
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
@@ -816,38 +870,38 @@ void paged_attention_v2_launcher(
}
}
#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_launcher<T, BLOCK_SIZE>( \
out, \
exp_sums, \
max_logits, \
tmp_out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
context_lens, \
max_context_len, \
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
out, \
exp_sums, \
max_logits, \
tmp_out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER(T, 8); \
break; \
case 16: \
CALL_V2_LAUNCHER(T, 16); \
break; \
case 32: \
CALL_V2_LAUNCHER(T, 32); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
break; \
case 16: \
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
break; \
case 32: \
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2(
@@ -864,15 +918,30 @@ void paged_attention_v2(
torch::Tensor& context_lens, // [num_seqs]
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8_e5m2") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
}