[ Kernel ] FP8 Dynamic Per Token Quant - Add scale_ub (#6593)

Co-authored-by: Varun Sundar Rabindranth <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath
2024-07-19 21:15:26 -04:00
committed by GitHub
parent e81522e879
commit 2e26564259
6 changed files with 86 additions and 39 deletions

View File

@@ -23,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
template <typename scalar_t>
template <bool is_scale_inverted>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
const scalar_t val, const float inverted_scale) {
float x = static_cast<float>(val) * inverted_scale;
float const val, float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r);
}
@@ -117,10 +123,10 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
return absmax_val;
}
template <typename scalar_t>
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
scalar_t const* __restrict__ input,
float const inverted_scale,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
@@ -135,16 +141,21 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;
out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}
@@ -158,15 +169,17 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);
scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid,
blockDim.x * gridDim.x);
scaled_fp8_conversion_vec<scalar_t, true>(
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}
template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, const int hidden_size) {
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
@@ -188,20 +201,27 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
}
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float block_absmax_val;
__shared__ float token_scale;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / FP8_E4M3_MAX;
if (scale_ub) {
token_scale = min(block_absmax_val_maybe, *scale_ub);
} else {
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
scale[token_idx] = token_scale;
}
__syncthreads();
float const inverted_scale = FP8_E4M3_MAX / block_absmax_val;
// Note that we don't use inverted scales so we can match FBGemm impl.
if (can_vectorize) {
scaled_fp8_conversion_vec(token_output, token_input, inverted_scale,
hidden_size, tid, blockDim.x);
scaled_fp8_conversion_vec<scalar_t, false>(
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale);
token_output[i] = scaled_fp8_conversion<false>(
static_cast<float>(token_input[i]), token_scale);
}
}
}
@@ -246,9 +266,10 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
});
}
void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales) {
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
@@ -264,6 +285,8 @@ void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(), hidden_size);
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
hidden_size);
});
}