58
csrc/quantization/w8a8/fp8/common.cuh
Normal file
58
csrc/quantization/w8a8/fp8/common.cuh
Normal file
@@ -0,0 +1,58 @@
|
||||
#pragma once
|
||||
|
||||
#include "quantization/vectorization.cuh"
|
||||
#include "quantization/utils.cuh"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include "amd/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
// Determines the preferred FP8 type for the current platform.
|
||||
// Note that for CUDA this just returns true,
|
||||
// but on ROCm it will check device props.
|
||||
static bool is_fp8_ocp() {
|
||||
#ifndef USE_ROCM
|
||||
return true;
|
||||
#else
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
std::string device_arch = dprops->gcnArchName;
|
||||
size_t substring = device_arch.find("gfx94");
|
||||
return substring == std::string::npos;
|
||||
#endif
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
float old;
|
||||
old = (value >= 0)
|
||||
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
|
||||
: __uint_as_float(
|
||||
atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||
|
||||
return old;
|
||||
}
|
||||
|
||||
template <bool is_scale_inverted, typename fp8_type>
|
||||
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
float const scale) {
|
||||
float x = 0.0f;
|
||||
if constexpr (is_scale_inverted) {
|
||||
x = val * scale;
|
||||
} else {
|
||||
x = val / scale;
|
||||
}
|
||||
|
||||
float r =
|
||||
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
|
||||
#ifndef USE_ROCM
|
||||
return static_cast<fp8_type>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return fp8::cvt_c10<fp8_type>(r);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
Reference in New Issue
Block a user