Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (#3290)
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: HaiShaw <hixiao@gmail.com> Co-authored-by: AdrianAbeyta <Adrian.Abeyta@amd.com> Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com> Co-authored-by: root <root@gt-pla-u18-08.pla.dcgpu> Co-authored-by: mawong-amd <156021403+mawong-amd@users.noreply.github.com> Co-authored-by: ttbachyinsda <ttbachyinsda@outlook.com> Co-authored-by: guofangze <guofangze@kuaishou.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -4,8 +4,10 @@
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
#include "quantization/fp8/amd_detail/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
@@ -151,7 +153,7 @@ void copy_blocks(
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
|
||||
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
|
||||
__global__ void reshape_and_cache_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
@@ -163,7 +165,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int x) {
|
||||
const int x,
|
||||
const float kv_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
if (slot_idx < 0) {
|
||||
@@ -195,10 +198,13 @@ __global__ void reshape_and_cache_kernel(
|
||||
+ block_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (is_fp8_e5m2_kv_cache) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
if constexpr (is_fp8_kv_cache) {
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
|
||||
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
|
||||
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
@@ -211,8 +217,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
@@ -223,7 +229,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
num_heads, \
|
||||
head_size, \
|
||||
block_size, \
|
||||
x);
|
||||
x, \
|
||||
kv_scale);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
@@ -231,7 +238,8 @@ void reshape_and_cache(
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype)
|
||||
const std::string& kv_cache_dtype,
|
||||
const float kv_scale)
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
@@ -254,7 +262,7 @@ void reshape_and_cache(
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
} else if (kv_cache_dtype == "fp8") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
@@ -270,15 +278,17 @@ void reshape_and_cache(
|
||||
namespace vllm {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__global__ void convert_fp8_e5m2_kernel(
|
||||
__global__ void convert_fp8_kernel(
|
||||
const Tin* __restrict__ src_cache,
|
||||
Tout* __restrict__ dst_cache,
|
||||
const int64_t block_stride) {
|
||||
const int64_t block_idx = blockIdx.x;
|
||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||
int64_t idx = block_idx * block_stride + i;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
@@ -287,16 +297,25 @@ __global__ void convert_fp8_e5m2_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
|
||||
vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||
#define CALL_CONVERT_FP8(Tout, Tin) \
|
||||
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||
block_stride);
|
||||
|
||||
void convert_fp8_e5m2(
|
||||
void convert_fp8(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache)
|
||||
{
|
||||
torch::Device src_device = src_cache.device();
|
||||
torch::Device dst_device = dst_cache.device();
|
||||
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
||||
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
|
||||
TORCH_CHECK(
|
||||
src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same GPU");
|
||||
at::cuda::OptionalCUDAGuard device_guard(src_device);
|
||||
|
||||
int64_t num_blocks = src_cache.size(0);
|
||||
int64_t block_stride = src_cache.stride(0);
|
||||
|
||||
@@ -305,16 +324,16 @@ void convert_fp8_e5m2(
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, float);
|
||||
CALL_CONVERT_FP8(uint8_t, float);
|
||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
|
||||
CALL_CONVERT_FP8(uint8_t, uint16_t);
|
||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
|
||||
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8_E5M2(float, uint8_t);
|
||||
CALL_CONVERT_FP8(float, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
|
||||
CALL_CONVERT_FP8(uint16_t, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
|
||||
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user