[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535)

This commit is contained in:
Cody Yu
2024-05-09 17:04:17 -07:00
committed by GitHub
parent 379da6dcb5
commit c833101740
17 changed files with 843 additions and 558 deletions

View File

@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"Failed to determine torch nvcc compiler flags")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
list(APPEND GPU_FLAGS "-DENABLE_FP8")
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8_E4M3"
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")