diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp
index 89cf2dc3a..fbe0e8778 100644
--- a/csrc/cpu/cpu_attn_impl.hpp
+++ b/csrc/cpu/cpu_attn_impl.hpp
@@ -821,7 +821,7 @@ struct VecTypeTrait {
   using vec_t = vec_op::BF16Vec16;
 };
 
-#if !defined(__powerpc__) && !defined(__s390x__)
+#if !defined(__powerpc__)
 template <>
 struct VecTypeTrait<c10::Half> {
   using vec_t = vec_op::FP16Vec16;
diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp
index 9efd8b7ec..700ba0306 100644
--- a/csrc/cpu/cpu_types_vxe.hpp
+++ b/csrc/cpu/cpu_types_vxe.hpp
@@ -16,10 +16,12 @@ namespace vec_op {
 #define vec_sr(a, b) ((a) >> (b))  // Vector Shift Right Algebraic
 #define vec_sl(a, b) ((a) << (b))  // Vector Shift Left
 
-// FIXME: FP16 is not fully supported in Torch-CPU
-#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)            \
-  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
-  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+// NOTE: FP16 (Half) is supported on s390x via custom bit-manipulation
+// conversion. PyTorch itself lacks native s390x FP16 support.
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)             \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
 
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
@@ -86,6 +88,39 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
   }
 };
 
+struct FP16Vec8 : public Vec<FP16Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+
+  __vector signed short reg;
+
+  explicit FP16Vec8(const void* ptr) : reg(*(__vector signed short*)ptr) {}
+  explicit FP16Vec8(const FP32Vec8&);
+
+  void save(void* ptr) const {
+    *reinterpret_cast<__vector signed short*>(ptr) = reg;
+  }
+};
+
+struct FP16Vec16 : public Vec<FP16Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+
+  ss16x8x2_t reg;
+
+  explicit FP16Vec16(const void* ptr) {
+    // Load 256 bits (16 FP16 values) in two parts
+    reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
+    reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
+  }
+
+  explicit FP16Vec16(const FP32Vec16&);
+
+  void save(void* ptr) const {
+    // Save 256 bits in two parts
+    vec_xst(reg.val[0], 0, (signed short*)ptr);
+    vec_xst(reg.val[1], 16, (signed short*)ptr);
+  }
+};
+
 struct BF16Vec16 : public Vec<BF16Vec16> {
   constexpr static int VEC_ELEM_NUM = 16;
 
@@ -108,6 +143,92 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
 
 const static __vector signed short zero = vec_splats((signed short)0);
 
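+// Bit-level FP16 <-> FP32 helpers. FP16 is 1 sign | 5 exponent (bias 15) |
+// 10 mantissa bits; FP32 is 1 | 8 (bias 127) | 23. Re-biasing therefore
+// adds or subtracts 127 - 15 = 112, and the mantissa moves by
+// 23 - 10 = 13 bits.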
+FORCE_INLINE __vector float fp16_to_fp32_bits(__vector unsigned int x) {
+  const __vector unsigned int mask_sign = {0x8000, 0x8000, 0x8000, 0x8000};
+  const __vector unsigned int mask_exp = {0x7C00, 0x7C00, 0x7C00, 0x7C00};
+  const __vector unsigned int mask_mant = {0x03FF, 0x03FF, 0x03FF, 0x03FF};
+  const __vector unsigned int bias_adj = {112, 112, 112, 112};
+  const __vector unsigned int zero_v = {0, 0, 0, 0};
+  const __vector unsigned int exp_max_fp16 = {0x1F, 0x1F, 0x1F,
+                                              0x1F};  // FP16 NaN/Inf exponent
+  const __vector unsigned int exp_max_fp32 = {0xFF, 0xFF, 0xFF,
+                                              0xFF};  // FP32 NaN/Inf exponent
+
+  __vector unsigned int s = (x & mask_sign) << 16;
+  __vector unsigned int e = (x & mask_exp) >> 10;
+  __vector unsigned int m = (x & mask_mant) << 13;
+
+  // Check for NaN/Inf: exponent = 0x1F in FP16
+  __vector __bool int is_nan_inf = vec_cmpeq(e, exp_max_fp16);
+  // Check for zero/subnormal inputs: exponent = 0 in FP16. These must be
+  // flushed to (signed) zero; naively re-biasing them would turn the bit
+  // pattern 0x0000 into 2^-15 instead of 0.0f.
+  __vector __bool int is_zero_exp = vec_cmpeq(e, zero_v);
+
+  // Normal: adjust bias; NaN/Inf: set to 0xFF; zero/subnormal: flush
+  __vector unsigned int e_normal = e + bias_adj;
+  e = vec_sel(e_normal, exp_max_fp32, is_nan_inf);
+  e = vec_sel(e, zero_v, is_zero_exp);
+  m = vec_sel(m, zero_v, is_zero_exp);
+
+  return (__vector float)(s | (e << 23) | m);
+}
+
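+// FP32 -> FP16 with IEEE round-to-nearest-even: the round bit is FP32
+// mantissa bit 12, bits [11:0] are the sticky bits, and ties go to the
+// value with an even FP16 mantissa (LSB == 0). Finite values above the
+// FP16 range saturate to the largest normal value instead of becoming Inf.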
+FORCE_INLINE __vector unsigned int fp32_to_fp16_bits(__vector float f_in) {
+  __vector unsigned int in = (__vector unsigned int)f_in;
+
+  const __vector unsigned int mask_sign_32 = {0x80000000, 0x80000000,
+                                              0x80000000, 0x80000000};
+  const __vector unsigned int mask_exp_32 = {0x7F800000, 0x7F800000,
+                                             0x7F800000, 0x7F800000};
+  const __vector unsigned int mask_mant_32 = {0x007FFFFF, 0x007FFFFF,
+                                              0x007FFFFF, 0x007FFFFF};
+
+  // Use SIGNED integers for exponent math to handle the underflow check
+  const __vector signed int bias_adj = {112, 112, 112, 112};
+  const __vector signed int zero = {0, 0, 0, 0};
+  const __vector signed int max_exp = {31, 31, 31, 31};  // Max FP16 exponent
+  const __vector unsigned int exp_max_fp32 = {0xFF, 0xFF, 0xFF, 0xFF};
+  const __vector unsigned int exp_max_fp16 = {0x1F, 0x1F, 0x1F, 0x1F};
+
+  __vector unsigned int s = (in & mask_sign_32) >> 16;
+  __vector unsigned int e_u = (in & mask_exp_32) >> 23;
+
+  // Check for NaN/Inf: exponent = 0xFF in FP32
+  __vector __bool int is_nan_inf = vec_cmpeq(e_u, exp_max_fp32);
+
+  __vector signed int e_s = (__vector signed int)e_u;
+  e_s = vec_sub(e_s, bias_adj);
+  e_s = vec_max(e_s, zero);
+  e_s = vec_min(e_s, max_exp);
+  __vector unsigned int e_normal = (__vector unsigned int)e_s;
+
+  __vector unsigned int e_final = vec_sel(e_normal, exp_max_fp16, is_nan_inf);
+
+  const __vector unsigned int one_v = {1, 1, 1, 1};
+  const __vector unsigned int mask_sticky = {0xFFF, 0xFFF, 0xFFF, 0xFFF};
+
+  __vector unsigned int round_bit = (in >> 12) & one_v;
+  __vector unsigned int sticky = in & mask_sticky;
+  __vector unsigned int m = (in & mask_mant_32) >> 13;
+  __vector unsigned int lsb = m & one_v;  // LSB of mantissa for tie-breaking
+
+  // Round up if: round_bit && (sticky || lsb); never round NaN/Inf inputs,
+  // whose mantissa must pass through unchanged
+  __vector __bool int sticky_nonzero =
+      vec_cmpgt(sticky, (__vector unsigned int){0, 0, 0, 0});
+  __vector __bool int lsb_set = vec_cmpeq(lsb, one_v);
+  __vector __bool int round_up =
+      vec_and(vec_cmpeq(round_bit, one_v), vec_or(sticky_nonzero, lsb_set));
+  round_up = vec_andc(round_up, is_nan_inf);
+
+  m = vec_sel(m, m + one_v, round_up);
+
+  const __vector unsigned int mant_mask = {0x3FF, 0x3FF, 0x3FF, 0x3FF};
+  const __vector unsigned int max_normal_exp = {0x1E, 0x1E, 0x1E, 0x1E};
+  __vector __bool int mant_overflows = vec_cmpgt(m, mant_mask);
+  __vector __bool int would_overflow_to_inf =
+      vec_and(mant_overflows, vec_cmpeq(e_final, max_normal_exp));
+  __vector unsigned int e_inc = vec_min(e_final + one_v, exp_max_fp16);
+  e_final = vec_sel(e_final, e_inc, mant_overflows);
+  m = vec_and(m, mant_mask);
+  e_final = vec_sel(e_final, max_normal_exp, would_overflow_to_inf);
+  m = vec_sel(m, mant_mask, would_overflow_to_inf);
+
+  // Saturate any remaining finite overflow (exponent clamped to 0x1F above)
+  // to the largest normal FP16 value, matching the scalar storeFP32 path;
+  // otherwise such values would read back as Inf/NaN bit patterns
+  __vector __bool int finite_overflow =
+      vec_andc(vec_cmpeq(e_final, exp_max_fp16), is_nan_inf);
+  e_final = vec_sel(e_final, max_normal_exp, finite_overflow);
+  m = vec_sel(m, mant_mask, finite_overflow);
+
+  return s | (e_final << 10) | m;
+}
+
 struct BF16Vec32 : public Vec<BF16Vec32> {
   constexpr static int VEC_ELEM_NUM = 32;
 
@@ -180,6 +301,18 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
     reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
   }
 
+  explicit FP32Vec8(const FP16Vec8& v) {
+    // Cast to UNSIGNED short vector to prevent sign-extension during unpack
+    __vector unsigned short raw_u = (__vector unsigned short)v.reg;
+
+    // Unpack 8x16-bit to two 4x32-bit vectors (zero-extended)
+    __vector unsigned int raw_hi = (__vector unsigned int)vec_unpackh(raw_u);
+    __vector unsigned int raw_lo = (__vector unsigned int)vec_unpackl(raw_u);
+
+    reg.val[0] = fp16_to_fp32_bits(raw_hi);
+    reg.val[1] = fp16_to_fp32_bits(raw_lo);
+  }
+
   float reduce_sum() const {
     AliasReg ar;
     ar.reg = reg;
@@ -531,6 +664,22 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
     reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
   }
 
+  explicit FP32Vec16(const FP16Vec16& v) {
+    __vector unsigned int raw_hi_0 =
+        (__vector unsigned int)vec_unpackh(v.reg.val[0]);
+    __vector unsigned int raw_lo_0 =
+        (__vector unsigned int)vec_unpackl(v.reg.val[0]);
+    reg.val[0] = fp16_to_fp32_bits(raw_hi_0);
+    reg.val[1] = fp16_to_fp32_bits(raw_lo_0);
+
+    __vector unsigned int raw_hi_1 =
+        (__vector unsigned int)vec_unpackh(v.reg.val[1]);
+    __vector unsigned int raw_lo_1 =
+        (__vector unsigned int)vec_unpackl(v.reg.val[1]);
+    reg.val[2] = fp16_to_fp32_bits(raw_hi_1);
+    reg.val[3] = fp16_to_fp32_bits(raw_lo_1);
+  }
+
   explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
 
   FP32Vec16 operator*(const FP32Vec16& b) const {
@@ -628,8 +777,10 @@ struct VecType {
   using vec_type = BF16Vec8;
 };
 
-// On s390x, FP16 (Half) is not natively supported, use FP32 vectors instead
-using FP16Vec16 = FP32Vec16;
+template <>
+struct VecType<c10::Half> {
+  using vec_type = FP16Vec8;
+};
 
 template <typename T>
 void storeFP32(float v, T* ptr) {
@@ -650,6 +801,52 @@ inline void storeFP32(float v, c10::BFloat16* ptr) {
   *ptr = *(v_ptr + 1);
 }
 
+template <>
+inline void storeFP32<::c10::Half>(float v, ::c10::Half* ptr) {
+  // Use bit-manipulation for the IEEE FP32 to FP16 conversion, since the
+  // vector intrinsics for FP32 to FP16 conversion do not use IEEE rounding
+  // and can produce incorrect results for some inputs.
+  uint32_t in;
+  std::memcpy(&in, &v, sizeof(in));
+
+  uint32_t s = (in & 0x80000000) >> 16;  // Sign
+  uint32_t e = (in & 0x7F800000) >> 23;  // Exponent
+  uint32_t round_bit = (in >> 12) & 1;
+  uint32_t sticky = (in & 0xFFF) != 0;  // Any bits in [11..0]
+  uint32_t m = (in & 0x007FFFFF) >> 13;
+  uint32_t lsb = m & 1;  // LSB of mantissa for tie-breaking
+
+  // Check for NaN/Inf before rounding; their mantissa must not be rounded
+  bool is_nan_inf = (e == 0xFF);
+
+  if (!is_nan_inf && round_bit && (sticky || lsb)) {
+    m++;
+    // Handle mantissa overflow: if m overflows 10 bits, increment exponent
+    if (m > 0x3FF) {
+      m = 0;
+      e++;
+    }
+  }
+
+  if (is_nan_inf) {
+    // NaN/Inf: preserve it
+    e = 0x1F;
+  } else {
+    // Normal: adjust bias (127 - 15), flush subnormals to zero
+    e = (e >= 112) ? (e - 112) : 0;
+    // If the exponent overflows into the Inf range, saturate to the largest
+    // normal FP16 value
+    if (e > 0x1E) {
+      e = 0x1E;   // Max normal exponent
+      m = 0x3FF;  // Max mantissa
+    }
+  }
+
+  uint16_t fp16 = (uint16_t)(s | (e << 10) | m);
+
+  *reinterpret_cast<uint16_t*>(ptr) = fp16;
+}
+
 #ifndef __VEC_CLASS_FP_NAN
 #define __VEC_CLASS_FP_NAN (1 << 6)
 #endif
@@ -803,6 +1000,44 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
   reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
 }
 
+inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
+  // Use bit-manipulation for the IEEE FP32 to FP16 conversion, since the
+  // vector intrinsics for FP32 to FP16 conversion do not use IEEE rounding
+  // and can produce incorrect results for some inputs. Process each of the
+  // two FP32 vectors separately.
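+  // Each fp32_to_fp16_bits call leaves an FP16 result in the low 16 bits
+  // of every 32-bit lane; vec_perm below gathers bytes 2-3 of each
+  // big-endian word to pack the eight results into one vector register.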
+  __vector unsigned int res_hi = fp32_to_fp16_bits(v.reg.val[0]);
+  __vector unsigned int res_lo = fp32_to_fp16_bits(v.reg.val[1]);
+
+  const __vector unsigned char perm_pack = {
+      2,  3,  6,  7,  10, 11, 14, 15,  // Select lower 2 bytes from res_hi
+      18, 19, 22, 23, 26, 27, 30, 31   // Select lower 2 bytes from res_lo
+  };
+
+  reg = vec_perm((__vector signed short)res_hi, (__vector signed short)res_lo,
+                 perm_pack);
+}
+
+inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
+  // Use bit-manipulation for the IEEE FP32 to FP16 conversion, since the
+  // vector intrinsics for FP32 to FP16 conversion do not use IEEE rounding
+  // and can produce incorrect results for some inputs. Process each of the
+  // four FP32 vectors separately.
+  __vector unsigned int res_0 = fp32_to_fp16_bits(v.reg.val[0]);
+  __vector unsigned int res_1 = fp32_to_fp16_bits(v.reg.val[1]);
+  __vector unsigned int res_2 = fp32_to_fp16_bits(v.reg.val[2]);
+  __vector unsigned int res_3 = fp32_to_fp16_bits(v.reg.val[3]);
+
+  const __vector unsigned char perm_pack = {
+      2,  3,  6,  7,  10, 11, 14, 15,  // Lower 2 bytes from first vector
+      18, 19, 22, 23, 26, 27, 30, 31   // Lower 2 bytes from second vector
+  };
+
+  reg.val[0] = vec_perm((__vector signed short)res_0,
+                        (__vector signed short)res_1, perm_pack);
+  reg.val[1] = vec_perm((__vector signed short)res_2,
+                        (__vector signed short)res_3, perm_pack);
+}
+
 // 1D softmax over `n` elements in `input`, writes result to `output`.
 // Uses FP32Vec8 for main body, scalar tail handling.
 // Requirement: n > 0
diff --git a/csrc/cpu/mla_decode.cpp b/csrc/cpu/mla_decode.cpp
index 564055ef5..582c480c3 100644
--- a/csrc/cpu/mla_decode.cpp
+++ b/csrc/cpu/mla_decode.cpp
@@ -18,8 +18,8 @@ struct KernelVecType {
 
 template <>
 struct KernelVecType<c10::Half> {
-#if defined(__powerpc64__) || defined(__s390x__)
-  // Power and s390x architecture-specific vector types
+#if defined(__powerpc64__)
+  // Power-specific vector types
   using qk_load_vec_type = vec_op::FP32Vec16;
   using qk_vec_type = vec_op::FP32Vec16;
   using v_load_vec_type = vec_op::FP32Vec16;
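
As a sanity check outside the patch itself, the following standalone program mirrors the scalar bit manipulation from storeFP32<::c10::Half> so the rounding and saturation behaviour can be exercised on any host, with no VXE intrinsics or PyTorch required. The helper name fp32_to_fp16_ref and the chosen test values are illustrative only, not part of vLLM.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Scalar mirror of the storeFP32<::c10::Half> bit manipulation above.
static uint16_t fp32_to_fp16_ref(float v) {
  uint32_t in;
  std::memcpy(&in, &v, sizeof(in));

  uint32_t s = (in & 0x80000000u) >> 16;  // Sign to FP16 bit 15
  uint32_t e = (in & 0x7F800000u) >> 23;  // Biased FP32 exponent
  uint32_t round_bit = (in >> 12) & 1;    // First dropped mantissa bit
  uint32_t sticky = (in & 0xFFFu) != 0;   // Any lower dropped bits
  uint32_t m = (in & 0x007FFFFFu) >> 13;  // Top 10 mantissa bits
  bool is_nan_inf = (e == 0xFF);

  // Round to nearest, ties to even
  if (!is_nan_inf && round_bit && (sticky || (m & 1))) {
    if (++m > 0x3FF) {
      m = 0;
      e++;
    }
  }

  if (is_nan_inf) {
    e = 0x1F;  // Preserve NaN/Inf
  } else {
    e = (e >= 112) ? (e - 112) : 0;  // Re-bias; flush subnormals
    if (e > 0x1E) {                  // Saturate instead of overflowing to Inf
      e = 0x1E;
      m = 0x3FF;
    }
  }
  return (uint16_t)(s | (e << 10) | m);
}

int main() {
  assert(fp32_to_fp16_ref(0.0f) == 0x0000);      // Zero stays zero
  assert(fp32_to_fp16_ref(1.0f) == 0x3C00);      // One
  assert(fp32_to_fp16_ref(-2.0f) == 0xC000);     // Sign handling
  assert(fp32_to_fp16_ref(65504.0f) == 0x7BFF);  // Largest normal FP16
  assert(fp32_to_fp16_ref(1e30f) == 0x7BFF);     // Saturates, no Inf
  // Ties round to the even mantissa: 1 + 2^-11 sits exactly between
  // 0x3C00 and 0x3C01, and 1 + 3 * 2^-11 between 0x3C01 and 0x3C02.
  assert(fp32_to_fp16_ref(1.0f + 0x1p-11f) == 0x3C00);
  assert(fp32_to_fp16_ref(1.0f + 3 * 0x1p-11f) == 0x3C02);
  std::puts("FP16 conversion checks passed");
  return 0;
}

The two tie cases are the interesting ones: they confirm round-to-nearest-even rather than round-half-up, which is exactly the round_bit/sticky/lsb logic the vectorized fp32_to_fp16_bits implements.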