[CPU] Enable FP16 (Half dtype) support for s390x (#34116)
Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
This commit is contained in:
@@ -821,7 +821,7 @@ struct VecTypeTrait<c10::BFloat16> {
|
||||
using vec_t = vec_op::BF16Vec16;
|
||||
};
|
||||
|
||||
#if !defined(__powerpc__) && !defined(__s390x__)
|
||||
#if !defined(__powerpc__)
|
||||
template <>
|
||||
struct VecTypeTrait<c10::Half> {
|
||||
using vec_t = vec_op::FP16Vec16;
|
||||
|
||||
@@ -16,10 +16,12 @@ namespace vec_op {
|
||||
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
|
||||
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
|
||||
|
||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
// NOTE: FP16 (Half) is supported on s390x via custom bit-manipulation
|
||||
// conversion. PyTorch itself lacks native s390x FP16 support.
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
@@ -86,6 +88,39 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
}
|
||||
};
|
||||
|
||||
// Eight IEEE half-precision values held as raw 16-bit patterns in a single
// 128-bit vector register. No arithmetic is defined here; the type only moves
// bits between memory and the register.
struct FP16Vec8 : public Vec<FP16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  // Raw FP16 bit patterns, one per 16-bit lane.
  __vector signed short reg;

  // Loads 16 bytes (8 halves) from `ptr` by dereferencing it as a vector.
  explicit FP16Vec8(const void* ptr)
      : reg(*reinterpret_cast<const __vector signed short*>(ptr)) {}

  // Narrowing conversion from eight FP32 lanes; defined out of line.
  explicit FP16Vec8(const FP32Vec8&);

  // Writes the 16 bytes held in `reg` to `ptr`.
  void save(void* ptr) const {
    auto* dst = reinterpret_cast<__vector signed short*>(ptr);
    *dst = reg;
  }
};
|
||||
|
||||
// Sixteen IEEE half-precision values held as raw 16-bit patterns in a pair of
// 128-bit vector registers (`reg.val[0]` = first 8 halves, `reg.val[1]` = next
// 8 halves).
struct FP16Vec16 : public Vec<FP16Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  ss16x8x2_t reg;

  // Loads 32 bytes (16 halves) from `ptr` as two 16-byte vector loads at byte
  // offsets 0 and 16.
  explicit FP16Vec16(const void* ptr) {
    const signed short* src = static_cast<const signed short*>(ptr);
    reg.val[0] = vec_xl(0, src);
    reg.val[1] = vec_xl(16, src);
  }

  // Narrowing conversion from sixteen FP32 lanes; defined out of line.
  explicit FP16Vec16(const FP32Vec16&);

  // Stores 32 bytes to `ptr` as two 16-byte vector stores at byte offsets 0
  // and 16.
  void save(void* ptr) const {
    signed short* dst = static_cast<signed short*>(ptr);
    vec_xst(reg.val[0], 0, dst);
    vec_xst(reg.val[1], 16, dst);
  }
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
|
||||
@@ -108,6 +143,92 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
const static __vector signed short zero = vec_splats((signed short)0);
|
||||
|
||||
// Widens four IEEE binary16 bit patterns to FP32.
//
// `x` carries one FP16 pattern in the low 16 bits of each 32-bit lane. Only
// bits 0..15 are read (every field is masked out explicitly), so whatever the
// caller left in the upper 16 bits -- zero- or sign-extension -- is ignored.
//
// Handling per exponent class:
//   * exponent 0x1F (Inf/NaN): FP32 exponent forced to 0xFF, payload kept.
//   * exponent 0 (zero/subnormal): decoded exactly as mantissa * 2^-24, so
//     +/-0 stays +/-0 instead of being re-biased into +/-2^-15.
//   * otherwise: exponent re-biased by 127 - 15 = 112, mantissa shifted up.
FORCE_INLINE __vector float fp16_to_fp32_bits(__vector unsigned int x) {
  const __vector unsigned int mask_sign = {0x8000, 0x8000, 0x8000, 0x8000};
  const __vector unsigned int mask_exp = {0x7C00, 0x7C00, 0x7C00, 0x7C00};
  const __vector unsigned int mask_mant = {0x03FF, 0x03FF, 0x03FF, 0x03FF};
  const __vector unsigned int bias_adj = {112, 112, 112, 112};
  const __vector unsigned int exp_zero = {0, 0, 0, 0};
  const __vector unsigned int exp_max_fp16 = {0x1F, 0x1F, 0x1F,
                                              0x1F};  // FP16 NaN/Inf exponent
  const __vector unsigned int exp_max_fp32 = {0xFF, 0xFF, 0xFF,
                                              0xFF};  // FP32 NaN/Inf exponent
  // Bit pattern of 0.5f. One mantissa step at 0.5f is 2^-24, which is exactly
  // the weight of one FP16 subnormal mantissa step.
  const __vector unsigned int half_bits = {0x3F000000, 0x3F000000, 0x3F000000,
                                           0x3F000000};

  const __vector unsigned int sign = (x & mask_sign) << 16;
  const __vector unsigned int exp16 = (x & mask_exp) >> 10;
  const __vector unsigned int mant16 = x & mask_mant;

  const __vector __bool int is_nan_inf = vec_cmpeq(exp16, exp_max_fp16);
  const __vector __bool int is_zero_or_subnormal = vec_cmpeq(exp16, exp_zero);

  // Normal / Inf / NaN lanes: re-bias (or saturate) the exponent and widen the
  // mantissa from 10 to 23 bits.
  const __vector unsigned int exp32 =
      vec_sel(exp16 + bias_adj, exp_max_fp32, is_nan_inf);
  const __vector unsigned int normal_bits = (exp32 << 23) | (mant16 << 13);

  // Zero / subnormal lanes: (0.5f | mant) reads as 0.5 + mant * 2^-24, so
  // subtracting 0.5f leaves mant * 2^-24. The subtraction is exact, and a
  // zero mantissa yields 0.0f.
  const __vector float subnormal_value =
      (__vector float)(half_bits | mant16) - (__vector float)half_bits;
  const __vector unsigned int subnormal_bits =
      (__vector unsigned int)subnormal_value;

  const __vector unsigned int magnitude =
      vec_sel(normal_bits, subnormal_bits, is_zero_or_subnormal);

  return (__vector float)(sign | magnitude);
}
|
||||
|
||||
// Narrows four FP32 lanes to IEEE binary16 bit patterns using integer
// bit-manipulation only. The result is returned in the low 16 bits of each
// 32-bit lane (upper 16 bits are zero).
//
// Behaviour per input class:
//   * NaN                      -> quiet FP16 NaN (top 10 payload bits kept);
//                                 never collapses to Inf.
//   * +/-Inf                   -> +/-Inf.
//   * finite, rounds >= 2^16   -> saturated to +/-65504 (0x7BFF), the largest
//                                 finite FP16 value, rather than Inf.
//   * rounds below 2^-14       -> flushed to +/-0 (no FP16 subnormals are
//                                 produced).
//   * everything else          -> round-to-nearest, ties-to-even.
FORCE_INLINE __vector unsigned int fp32_to_fp16_bits(__vector float f_in) {
  const __vector unsigned int in = (__vector unsigned int)f_in;

  const __vector unsigned int mask_abs = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF,
                                          0x7FFFFFFF};
  const __vector unsigned int mask_sign_16 = {0x8000, 0x8000, 0x8000, 0x8000};
  const __vector unsigned int mask_mant_16 = {0x3FF, 0x3FF, 0x3FF, 0x3FF};
  const __vector unsigned int one_v = {1, 1, 1, 1};
  // Rounding bias for dropping the low 13 mantissa bits (see below).
  const __vector unsigned int round_bias = {0xFFF, 0xFFF, 0xFFF, 0xFFF};
  // FP32 bit patterns used as thresholds on the magnitude.
  const __vector unsigned int fp32_inf = {0x7F800000, 0x7F800000, 0x7F800000,
                                          0x7F800000};
  const __vector unsigned int fp32_min_normal_fp16 = {
      0x38800000, 0x38800000, 0x38800000, 0x38800000};  // 2^-14
  const __vector unsigned int fp32_below_two_pow_16 = {
      0x477FFFFF, 0x477FFFFF, 0x477FFFFF, 0x477FFFFF};  // largest < 2^16
  // Exponent re-bias (127 - 15 = 112) positioned at the FP16 exponent field.
  const __vector unsigned int rebias = {112u << 10, 112u << 10, 112u << 10,
                                        112u << 10};
  // FP16 result patterns for the special cases.
  const __vector unsigned int fp16_zero = {0, 0, 0, 0};
  const __vector unsigned int fp16_max = {0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF};
  const __vector unsigned int fp16_inf = {0x7C00, 0x7C00, 0x7C00, 0x7C00};
  const __vector unsigned int fp16_quiet_nan = {0x7E00, 0x7E00, 0x7E00,
                                                0x7E00};

  const __vector unsigned int sign = (in >> 16) & mask_sign_16;
  const __vector unsigned int magnitude = in & mask_abs;

  // Round to nearest even on the combined exponent|mantissa word: adding
  // 0xFFF + lsb carries into bit 13 exactly when the discarded 13 bits are
  // above half, or equal to half with an odd kept mantissa. A mantissa carry
  // propagates into the exponent field on its own.
  const __vector unsigned int lsb = (magnitude >> 13) & one_v;
  const __vector unsigned int rounded = magnitude + round_bias + lsb;

  // Valid only for lanes in the FP16 normal range; other lanes are replaced
  // by the selects below (unsigned wrap-around there is harmless).
  const __vector unsigned int normal = (rounded >> 13) - rebias;

  const __vector __bool int is_tiny = vec_cmpgt(fp32_min_normal_fp16, rounded);
  const __vector __bool int is_huge = vec_cmpgt(rounded, fp32_below_two_pow_16);
  const __vector __bool int is_inf = vec_cmpeq(magnitude, fp32_inf);
  const __vector __bool int is_nan = vec_cmpgt(magnitude, fp32_inf);

  // Keep the top payload bits but force the quiet bit so the mantissa can
  // never end up all-zero (which would read back as Inf).
  const __vector unsigned int nan_bits =
      fp16_quiet_nan | ((magnitude >> 13) & mask_mant_16);

  // Order matters: Inf/NaN lanes also satisfy `is_huge`, so they are
  // overridden after the saturation select.
  __vector unsigned int result = normal;
  result = vec_sel(result, fp16_zero, is_tiny);
  result = vec_sel(result, fp16_max, is_huge);
  result = vec_sel(result, fp16_inf, is_inf);
  result = vec_sel(result, nan_bits, is_nan);

  return sign | result;
}
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
constexpr static int VEC_ELEM_NUM = 32;
|
||||
|
||||
@@ -180,6 +301,18 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const FP16Vec8& v) {
|
||||
// Cast to UNSIGNED short vector to prevent sign-extension during unpack
|
||||
__vector unsigned short raw_u = (__vector unsigned short)v.reg;
|
||||
|
||||
// Unpack 8x16-bit to two 4x32-bit vectors (Zero extended)
|
||||
__vector unsigned int raw_hi = (__vector unsigned int)vec_unpackh(raw_u);
|
||||
__vector unsigned int raw_lo = (__vector unsigned int)vec_unpackl(raw_u);
|
||||
|
||||
reg.val[0] = fp16_to_fp32_bits(raw_hi);
|
||||
reg.val[1] = fp16_to_fp32_bits(raw_lo);
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
@@ -531,6 +664,22 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16& v) {
|
||||
__vector unsigned int raw_hi_0 =
|
||||
(__vector unsigned int)vec_unpackh(v.reg.val[0]);
|
||||
__vector unsigned int raw_lo_0 =
|
||||
(__vector unsigned int)vec_unpackl(v.reg.val[0]);
|
||||
reg.val[0] = fp16_to_fp32_bits(raw_hi_0);
|
||||
reg.val[1] = fp16_to_fp32_bits(raw_lo_0);
|
||||
|
||||
__vector unsigned int raw_hi_1 =
|
||||
(__vector unsigned int)vec_unpackh(v.reg.val[1]);
|
||||
__vector unsigned int raw_lo_1 =
|
||||
(__vector unsigned int)vec_unpackl(v.reg.val[1]);
|
||||
reg.val[2] = fp16_to_fp32_bits(raw_hi_1);
|
||||
reg.val[3] = fp16_to_fp32_bits(raw_lo_1);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
@@ -628,8 +777,10 @@ struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
|
||||
// On s390x, FP16 (Half) is not natively supported, use FP32 vectors instead
|
||||
using FP16Vec16 = FP32Vec16;
|
||||
template <>
|
||||
struct VecType<c10::Half> {
|
||||
using vec_type = FP16Vec8;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
@@ -650,6 +801,52 @@ inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void storeFP32<::c10::Half>(float v, ::c10::Half* ptr) {
|
||||
// Use bit-manipulation for IEEE FP32 to FP16 conversion since vector
|
||||
// intrinsics for FP32 to FP16 conversion does not use IEEE rounding and can
|
||||
// produce incorrect results for some inputs. Process each of the 4 vectors
|
||||
// separately.
|
||||
uint32_t in;
|
||||
std::memcpy(&in, &v, sizeof(in));
|
||||
|
||||
uint32_t s = (in & 0x80000000) >> 16; // Sign
|
||||
uint32_t e = (in & 0x7F800000) >> 23; // Exponent
|
||||
uint32_t round_bit = (in >> 12) & 1;
|
||||
uint32_t sticky = (in & 0xFFF) != 0; // Any bits in [11..0]
|
||||
uint32_t m = (in & 0x007FFFFF) >> 13;
|
||||
uint32_t lsb = m & 1; // LSB of mantissa for tie-breaking
|
||||
|
||||
// Check for NaN/Inf before rounding
|
||||
bool is_nan_inf = (e == 0xFF);
|
||||
|
||||
if (round_bit && (sticky || lsb)) {
|
||||
m++;
|
||||
// Handle mantissa overflow: if m overflows 10 bits, increment exponent
|
||||
if (m > 0x3FF) {
|
||||
m = 0;
|
||||
e++;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_nan_inf) {
|
||||
// NaN/Inf: preserve it
|
||||
e = 0x1F;
|
||||
} else {
|
||||
// Normal: adjust bias (127 - 15), flush subnormals to zero
|
||||
e = (e >= 112) ? (e - 112) : 0;
|
||||
// If exponent overflows to Inf range, saturate to max normal FP16 value
|
||||
if (e > 0x1E) {
|
||||
e = 0x1E; // Max normal exponent
|
||||
m = 0x3FF; // Max mantissa
|
||||
}
|
||||
}
|
||||
|
||||
uint16_t fp16 = (uint16_t)(s | (e << 10) | m);
|
||||
|
||||
*reinterpret_cast<uint16_t*>(ptr) = fp16;
|
||||
}
|
||||
|
||||
#ifndef __VEC_CLASS_FP_NAN
|
||||
#define __VEC_CLASS_FP_NAN (1 << 6)
|
||||
#endif
|
||||
@@ -803,6 +1000,44 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
|
||||
}
|
||||
|
||||
// Narrows eight FP32 lanes to eight FP16 bit patterns.
//
// The conversion goes through the bit-manipulation helper rather than the
// hardware FP32->FP16 intrinsics, which (per the original author's note) do
// not apply IEEE rounding and can give wrong results for some inputs. Each of
// the two 4-lane input vectors is converted on its own and the results are
// then packed together.
inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
  const __vector unsigned int first4 = fp32_to_fp16_bits(v.reg.val[0]);
  const __vector unsigned int last4 = fp32_to_fp16_bits(v.reg.val[1]);

  // Byte selector for vec_perm: indices 0..15 address `first4`, 16..31 address
  // `last4`. Bytes 2-3 of every 32-bit lane are picked, i.e. the lower 2 bytes
  // that hold the converted half.
  const __vector unsigned char pick_low_halves = {
      2,  3,  6,  7,  10, 11, 14, 15,  // from first4
      18, 19, 22, 23, 26, 27, 30, 31   // from last4
  };

  const __vector signed short a = (__vector signed short)first4;
  const __vector signed short b = (__vector signed short)last4;
  reg = vec_perm(a, b, pick_low_halves);
}
|
||||
|
||||
// Narrows sixteen FP32 lanes to sixteen FP16 bit patterns.
//
// The conversion goes through the bit-manipulation helper rather than the
// hardware FP32->FP16 intrinsics, which (per the original author's note) do
// not apply IEEE rounding and can give wrong results for some inputs. Each of
// the four 4-lane input vectors is converted on its own; results are packed
// pairwise: val[0] <- inputs 0 and 1, val[1] <- inputs 2 and 3.
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
  const __vector signed short q0 =
      (__vector signed short)fp32_to_fp16_bits(v.reg.val[0]);
  const __vector signed short q1 =
      (__vector signed short)fp32_to_fp16_bits(v.reg.val[1]);
  const __vector signed short q2 =
      (__vector signed short)fp32_to_fp16_bits(v.reg.val[2]);
  const __vector signed short q3 =
      (__vector signed short)fp32_to_fp16_bits(v.reg.val[3]);

  // Byte selector for vec_perm: indices 0..15 address the first operand,
  // 16..31 the second. Bytes 2-3 of every 32-bit lane are picked, i.e. the
  // lower 2 bytes that hold the converted half.
  const __vector unsigned char pick_low_halves = {
      2,  3,  6,  7,  10, 11, 14, 15,  // from the first operand
      18, 19, 22, 23, 26, 27, 30, 31   // from the second operand
  };

  reg.val[0] = vec_perm(q0, q1, pick_low_halves);
  reg.val[1] = vec_perm(q2, q3, pick_low_halves);
}
|
||||
|
||||
// 1D softmax over `n` elements in `input`, writes result to `output`.
|
||||
// Uses FP32Vec8 for main body, scalar tail handling.
|
||||
// Requirement: n > 0
|
||||
|
||||
@@ -18,8 +18,8 @@ struct KernelVecType<float> {
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power and s390x architecture-specific vector types
|
||||
#if defined(__powerpc64__)
|
||||
// Power specific vector types
|
||||
using qk_load_vec_type = vec_op::FP32Vec16;
|
||||
using qk_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP32Vec16;
|
||||
|
||||
Reference in New Issue
Block a user