[Feature][ROCm]Enable fusion pass for torch.compile on ROCm (#15050)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
@@ -1,20 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "quantization/vectorization.cuh"
|
||||
#include "quantization/utils.cuh"
|
||||
|
||||
#include <cmath>
|
||||
#include <c10/core/ScalarType.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
|
||||
#else
|
||||
#include <ATen/hip/HIPContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
#ifdef USE_ROCM
|
||||
#include "amd/quant_utils.cuh"
|
||||
// ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
|
||||
#define MAYBE_HOST_DEVICE
|
||||
#endif
|
||||
|
||||
// Determines the preferred FP8 type for the current platform.
|
||||
@@ -31,29 +23,6 @@ static bool is_fp8_ocp() {
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct fp8_e4m3_adjusted_max;
|
||||
|
||||
template <>
|
||||
struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
|
||||
static constexpr c10::Float8_e4m3fn val() {
|
||||
return std::numeric_limits<c10::Float8_e4m3fn>::max();
|
||||
}
|
||||
};
|
||||
|
||||
// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
|
||||
// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
|
||||
template <>
|
||||
struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
|
||||
static constexpr c10::Float8_e4m3fnuz val() {
|
||||
return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
|
||||
fp8_e4m3_adjusted_max<T>::val();
|
||||
|
||||
namespace vllm {
|
||||
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
@@ -76,8 +45,8 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
x = val / scale;
|
||||
}
|
||||
|
||||
float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
|
||||
fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
|
||||
float r =
|
||||
fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
|
||||
#ifndef USE_ROCM
|
||||
return static_cast<fp8_type>(r);
|
||||
#else
|
||||
@@ -123,7 +92,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
// Finally, since cache[0] contains the maximum for this thread block,
|
||||
// atomically write the max to the target location
|
||||
if (threadIdx.x == 0) {
|
||||
atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
|
||||
atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user