diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp index 12c6f5d30..98f55d7c0 100644 --- a/csrc/cpu/cpu_attn_impl.hpp +++ b/csrc/cpu/cpu_attn_impl.hpp @@ -847,7 +847,7 @@ struct VecTypeTrait { }; #endif -#if !defined(__powerpc__) +#if !defined(__powerpc__) && !defined(__s390x__) template <> struct VecTypeTrait { using vec_t = vec_op::FP16Vec16; diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp index 51bca37e6..9efd8b7ec 100644 --- a/csrc/cpu/cpu_types_vxe.hpp +++ b/csrc/cpu/cpu_types_vxe.hpp @@ -4,6 +4,7 @@ #include #include +#include #include namespace vec_op { @@ -174,8 +175,9 @@ struct FP32Vec8 : public Vec { } explicit FP32Vec8(const BF16Vec8& v) { - reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); - reg.val[1] = (__vector float)vec_mergel(zero, v.reg); + // On big-endian s390x, place BF16 first to get correct byte order + reg.val[0] = (__vector float)vec_mergeh(v.reg, zero); + reg.val[1] = (__vector float)vec_mergel(v.reg, zero); } float reduce_sum() const { @@ -189,51 +191,257 @@ struct FP32Vec8 : public Vec { } FP32Vec8 exp() const { - // TODO: Vectorize this - AliasReg ar; - ar.reg = reg; - f32x4x4_t ret; - ret.val[0][0] = std::exp(ar.values[0]); - ret.val[0][1] = std::exp(ar.values[1]); - ret.val[0][2] = std::exp(ar.values[2]); - ret.val[0][3] = std::exp(ar.values[3]); - ret.val[1][0] = std::exp(ar.values[4]); - ret.val[1][1] = std::exp(ar.values[5]); - ret.val[1][2] = std::exp(ar.values[6]); - ret.val[1][3] = std::exp(ar.values[7]); - return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + f32x4x2_t out; + + const __vector float log2e = vec_splats(1.44269504088896341f); + const __vector float one = vec_splats(1.0f); + const __vector float min_x = vec_splats(-87.3f); + const __vector float max_x = vec_splats(88.7f); + + // 5th-degree Taylor polynomial for 2^r (r in [0,1)): c_k = ln(2)^k / k! + const __vector float c1 = vec_splats(0.6931471805599453f); + const __vector float c2 = vec_splats(0.240226506959101f); + const __vector float 
c3 = vec_splats(0.05550410866482158f); + const __vector float c4 = vec_splats(0.009618129107628477f); + const __vector float c5 = vec_splats(0.0013333558146428443f); + + for (int i = 0; i < 2; i++) { + __vector float x = reg.val[i]; + + x = vec_max(x, min_x); + x = vec_min(x, max_x); + + __vector float y = vec_mul(x, log2e); + + __vector float kf = vec_floor(y); + __vector float r = vec_sub(y, kf); + + __vector signed int k = vec_signed(kf); + const __vector signed int min_k = vec_splats((signed int)-126); + const __vector signed int max_k = vec_splats((signed int)127); + k = vec_min(vec_max(k, min_k), max_k); + + // Build 2^k from exponent bits + __vector signed int exp_int = vec_add(k, vec_splats((signed int)127)); + __vector unsigned int bits = (__vector unsigned int)exp_int; + bits = vec_sl(bits, vec_splats((unsigned int)23)); + __vector float pow2k = (__vector float)bits; + + // Horner evaluation of 1 + c1*r + c2*r^2 + c3*r^3 + c4*r^4 + c5*r^5 ~ 2^r + __vector float poly = vec_madd(c5, r, c4); + poly = vec_madd(poly, r, c3); + poly = vec_madd(poly, r, c2); + poly = vec_madd(poly, r, c1); + poly = vec_madd(poly, r, one); + + out.val[i] = vec_mul(pow2k, poly); + } + + return FP32Vec8(out); } FP32Vec8 tanh() const { - // TODO: Vectorize this - AliasReg ar; - ar.reg = reg; - f32x4x4_t ret; - ret.val[0][0] = std::tanh(ar.values[0]); - ret.val[0][1] = std::tanh(ar.values[1]); - ret.val[0][2] = std::tanh(ar.values[2]); - ret.val[0][3] = std::tanh(ar.values[3]); - ret.val[1][0] = std::tanh(ar.values[4]); - ret.val[1][1] = std::tanh(ar.values[5]); - ret.val[1][2] = std::tanh(ar.values[6]); - ret.val[1][3] = std::tanh(ar.values[7]); - return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) + const __vector float one = vec_splats(1.0f); + const __vector float two = vec_splats(2.0f); + const __vector float zero = vec_splats(0.0f); + const __vector float sat = + vec_splats(9.0f); // beyond this, tanh(x) ~ sign(x) + + f32x4x2_t out; + + for (int i = 0; i < 2; i++) { + 
__vector float x = reg.val[i]; + __vector float ax = vec_abs(x); + + // sign(x): +1 or -1 + __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero)); + + // saturation mask: |x| > sat + __vector __bool int saturated = vec_cmpgt(ax, sat); + + // 2x + __vector float two_x = vec_mul(x, two); + + // Build a temporary FP32Vec8 with both lanes = 2x, reuse exp() + f32x4x2_t tmp; + tmp.val[0] = two_x; + tmp.val[1] = two_x; + FP32Vec8 exp_2x_vec(tmp); + + FP32Vec8 e2x = exp_2x_vec.exp(); + __vector float e = e2x.reg.val[i]; + + // tanh(x) = (e - 1) / (e + 1) + __vector float num = vec_sub(e, one); + __vector float den = vec_add(e, one); + + __vector float t = vec_div(num, den); + + // For large |x|, clamp to sign(x) + out.val[i] = vec_sel(t, sign, saturated); + } + + return FP32Vec8(out); } FP32Vec8 er() const { - // TODO: Vectorize this - AliasReg ar; - ar.reg = reg; - f32x4x4_t ret; - ret.val[0][0] = std::erf(ar.values[0]); - ret.val[0][1] = std::erf(ar.values[1]); - ret.val[0][2] = std::erf(ar.values[2]); - ret.val[0][3] = std::erf(ar.values[3]); - ret.val[1][0] = std::erf(ar.values[4]); - ret.val[1][1] = std::erf(ar.values[5]); - ret.val[1][2] = std::erf(ar.values[6]); - ret.val[1][3] = std::erf(ar.values[7]); - return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + // A&S 7.1.26 approximation: + // erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t * + // exp(-x^2)) t = 1 / (1 + p*|x|), p = 0.3275911 + + const __vector float one = vec_splats(1.0f); + const __vector float zero = vec_splats(0.0f); + const __vector float p = vec_splats(0.3275911f); + + // Polynomial coeffs + const __vector float a1 = vec_splats(0.254829592f); + const __vector float a2 = vec_splats(-0.284496736f); + const __vector float a3 = vec_splats(1.421413741f); + const __vector float a4 = vec_splats(-1.453152027f); + const __vector float a5 = vec_splats(1.061405429f); + + // Threshold where erf(x) ~ sign(x) + const __vector float sat = vec_splats(6.0f); + + f32x4x2_t 
out; + + for (int lane = 0; lane < 2; lane++) { + __vector float x = reg.val[lane]; + __vector float ax = vec_abs(x); + + // sign(x) + __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero)); + + // |x| > 6 → erf(x) = ±1 + __vector __bool int saturated = vec_cmpgt(ax, sat); + + // t = 1 / (1 + p * |x|) + __vector float t = vec_madd(p, ax, one); + t = vec_div(one, t); + + // poly = a5 + __vector float poly = a5; + poly = vec_madd(poly, t, a4); + poly = vec_madd(poly, t, a3); + poly = vec_madd(poly, t, a2); + poly = vec_madd(poly, t, a1); + + // full polynomial: poly = poly * t + poly = vec_mul(poly, t); + + // Compute exp(-x^2) + __vector float x2 = vec_mul(x, x); + __vector float neg_x2 = vec_neg(x2); + + f32x4x2_t tmp; + tmp.val[0] = neg_x2; + tmp.val[1] = neg_x2; + FP32Vec8 exp_neg_x2(tmp); + + FP32Vec8 e = exp_neg_x2.exp(); + __vector float ex = e.reg.val[lane]; + + // erf(x) = sign * (1 - poly * exp(-x^2)) + __vector float term = vec_mul(poly, ex); + __vector float y = vec_sub(one, term); + y = vec_mul(y, sign); + + // saturated → ±1 + __vector float sat_val = vec_mul(sign, one); + out.val[lane] = vec_sel(y, sat_val, saturated); + } + + return FP32Vec8(out); + } + // Elementwise sigmoid(x) = 1 / (1 + exp(-x)) + FP32Vec8 sigmoid() const { + const __vector float one = vec_splats(1.0f); + + f32x4x2_t neg; + for (int i = 0; i < 2; ++i) { + neg.val[i] = vec_neg(reg.val[i]); + } + + FP32Vec8 neg_x(neg); + FP32Vec8 e = neg_x.exp(); // exp(-x) + + f32x4x2_t denom; + for (int i = 0; i < 2; ++i) { + denom.val[i] = vec_add(one, e.reg.val[i]); + } + + FP32Vec8 denom_vec(denom); + FP32Vec8 one_vec(1.0f); + + return one_vec / denom_vec; + } + + // Tanh-based GELU: + // gelu(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3))) + FP32Vec8 gelu_tanh() const { + const __vector float k_s2pi = vec_splats(0.7978845608028654f); // √(2/π) + const __vector float k_0_0447 = vec_splats(0.044715f); + + f32x4x2_t x2, x3, inner; + for (int i = 0; i < 2; ++i) { + 
__vector float x = reg.val[i]; + x2.val[i] = vec_mul(x, x); // x^2 + x3.val[i] = vec_mul(x2.val[i], x); // x^3 + __vector float t = vec_madd(k_0_0447, x3.val[i], x); // x + 0.044715*x^3 + inner.val[i] = vec_mul(k_s2pi, t); // √(2/π)*(...) + } + + FP32Vec8 inner_vec(inner); + FP32Vec8 t = inner_vec.tanh(); // tanh part + + FP32Vec8 one_vec(1.0f); + FP32Vec8 half_vec(0.5f); + + FP32Vec8 x_vec(*this); + return x_vec * half_vec * (one_vec + t); + } + + // Erf-based GELU: + // gelu(x) = 0.5 * x * (1 + erf(x / √2)) + FP32Vec8 gelu_erf() const { + const __vector float inv_sqrt2 = vec_splats(0.7071067811865476f); // 1/√2 + FP32Vec8 x_vec(*this); + + f32x4x2_t scaled; + for (int i = 0; i < 2; ++i) { + scaled.val[i] = vec_mul(reg.val[i], inv_sqrt2); + } + FP32Vec8 x_scaled(scaled); + + FP32Vec8 erf_x = x_scaled.er(); + + FP32Vec8 one_vec(1.0f); + FP32Vec8 half_vec(0.5f); + + return x_vec * half_vec * (one_vec + erf_x); + } + + // Elementwise reciprocal: 1/x (scalar per lane, for correctness) + FP32Vec8 rcp() const { + AliasReg in, out; + in.reg = reg; + + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + out.values[i] = 1.0f / in.values[i]; + } + return FP32Vec8(out.reg); + } + + // Elementwise rsqrt(x) = 1 / sqrt(x) (scalar per lane, for correctness) + FP32Vec8 rsqrt() const { + AliasReg in, out; + in.reg = reg; + + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + out.values[i] = 1.0f / std::sqrt(in.values[i]); + } + return FP32Vec8(out.reg); } FP32Vec8 operator*(const FP32Vec8& b) const { @@ -316,10 +524,11 @@ struct FP32Vec16 : public Vec { } explicit FP32Vec16(const BF16Vec16& v) { - reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]); - reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); - reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); - reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); + // On big-endian s390x, place BF16 first to get correct byte order + reg.val[0] = (__vector float)vec_mergeh(v.reg.val[0], zero); + reg.val[1] = (__vector 
float)vec_mergel(v.reg.val[0], zero); + reg.val[2] = (__vector float)vec_mergeh(v.reg.val[1], zero); + reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero); } explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} @@ -376,6 +585,23 @@ struct FP32Vec16 : public Vec { return result; } + FP32Vec16 max(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]), + vec_max(reg.val[1], b.reg.val[1]), + vec_max(reg.val[2], b.reg.val[2]), + vec_max(reg.val[3], b.reg.val[3])})); + } + + float reduce_max() const { + AliasReg ar; + ar.reg = reg; + float result = ar.values[0]; + unroll_loop([&result, &ar](int i) { + if (ar.values[i] > result) result = ar.values[i]; + }); + return result; + } + void save(float* ptr) const { vec_xst(reg.val[0], 0, ptr); vec_xst(reg.val[1], 16, ptr); @@ -402,15 +628,14 @@ struct VecType { using vec_type = BF16Vec8; }; +// On s390x, FP16 (Half) is not natively supported, use FP32 vectors instead +using FP16Vec16 = FP32Vec16; + template void storeFP32(float v, T* ptr) { *ptr = v; } -inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { - acc = acc + a * b; -} - namespace c10 { struct BFloat16 { uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit @@ -429,6 +654,79 @@ inline void storeFP32(float v, c10::BFloat16* ptr) { #define __VEC_CLASS_FP_NAN (1 << 6) #endif +// Optimized FMA (Fused Multiply-Add) implementations using IBM Z vector +// intrinsics + +// FP32Vec4 FMA: acc = acc + (a * b) or equivalently acc = fma(a, b, acc) +FORCE_INLINE void fma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) { + acc.reg = vec_madd(a.reg, b.reg, acc.reg); +} + +// FP32Vec8 FMA: acc = acc + (a * b) +FORCE_INLINE void fma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) { + acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]); + acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]); +} + +// FP32Vec16 FMA: acc = acc + (a * b) +FORCE_INLINE 
void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) { + acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]); + acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]); + acc.reg.val[2] = vec_madd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]); + acc.reg.val[3] = vec_madd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]); +} + +// Multiply-Subtract: acc = (a * b) - acc +FORCE_INLINE void fms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) { + acc.reg = vec_msub(a.reg, b.reg, acc.reg); +} + +FORCE_INLINE void fms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) { + acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]); + acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]); +} + +FORCE_INLINE void fms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) { + acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]); + acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]); + acc.reg.val[2] = vec_msub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]); + acc.reg.val[3] = vec_msub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]); +} + +// Negative Multiply-Add: acc = -((a * b) + acc) +FORCE_INLINE void nfma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) { + acc.reg = vec_nmadd(a.reg, b.reg, acc.reg); +} + +FORCE_INLINE void nfma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) { + acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]); + acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]); +} + +FORCE_INLINE void nfma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) { + acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]); + acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]); + acc.reg.val[2] = vec_nmadd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]); + acc.reg.val[3] = vec_nmadd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]); +} + +// Negative Multiply-Subtract: acc = -((a * b) - acc) = acc - (a * b) 
+FORCE_INLINE void nfms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) { + acc.reg = vec_nmsub(a.reg, b.reg, acc.reg); +} + +FORCE_INLINE void nfms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) { + acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]); + acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]); +} + +FORCE_INLINE void nfms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) { + acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]); + acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]); + acc.reg.val[2] = vec_nmsub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]); + acc.reg.val[3] = vec_nmsub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]); +} + const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31}; const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff, @@ -441,13 +739,24 @@ const static __vector unsigned int one = {1, 1, 1, 1}; inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + __vector unsigned int lsb0 = inp0 >> sh16; + __vector unsigned int lsb1 = inp1 >> sh16; + lsb0 = lsb0 & one; + lsb1 = lsb1 & one; + __vector unsigned int rnd0 = lsb0 + bias; + __vector unsigned int rnd1 = lsb1 + bias; + inp0 = inp0 + rnd0; + inp1 = inp1 + rnd1; int cc; __vector __bool int sel0 = vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc); __vector __bool int sel1 = vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc); - inp0 = vec_sel(inp0, nan, sel0) >> sh16; - inp1 = vec_sel(inp1, nan, sel1) >> sh16; + inp0 = vec_sel(inp0, nan, sel0); + inp1 = vec_sel(inp1, nan, sel1); + inp0 = inp0 >> sh16; + inp1 = inp1 >> sh16; + reg = (__vector signed short)vec_perm(inp0, inp1, omask); } @@ -456,6 +765,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { __vector unsigned int 
inp1 = (__vector unsigned int)(v.reg.val[1]); __vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]); __vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]); + __vector unsigned int lsb0 = inp0 >> sh16; + __vector unsigned int lsb1 = inp1 >> sh16; + __vector unsigned int lsb2 = inp2 >> sh16; + __vector unsigned int lsb3 = inp3 >> sh16; + lsb0 = lsb0 & one; + lsb1 = lsb1 & one; + lsb2 = lsb2 & one; + lsb3 = lsb3 & one; + __vector unsigned int rnd0 = lsb0 + bias; + __vector unsigned int rnd1 = lsb1 + bias; + __vector unsigned int rnd2 = lsb2 + bias; + __vector unsigned int rnd3 = lsb3 + bias; + inp0 = inp0 + rnd0; + inp1 = inp1 + rnd1; + inp2 = inp2 + rnd2; + inp3 = inp3 + rnd3; int cc; __vector __bool int sel0 = vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc); @@ -465,15 +790,164 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc); __vector __bool int sel3 = vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc); - inp0 = vec_sel(inp0, nan, sel0) >> sh16; - inp1 = vec_sel(inp1, nan, sel1) >> sh16; - inp2 = vec_sel(inp2, nan, sel2) >> sh16; - inp3 = vec_sel(inp3, nan, sel3) >> sh16; + inp0 = vec_sel(inp0, nan, sel0); + inp1 = vec_sel(inp1, nan, sel1); + inp2 = vec_sel(inp2, nan, sel2); + inp3 = vec_sel(inp3, nan, sel3); + inp0 = inp0 >> sh16; + inp1 = inp1 >> sh16; + inp2 = inp2 >> sh16; + inp3 = inp3 >> sh16; + reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask); reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask); } -inline void prefetch(const void* addr) { void __dcbt(const void* addr); } +// 1D softmax over `n` elements in `input`, writes result to `output`. +// Uses FP32Vec8 for main body, scalar tail handling. 
+// Requirement: n > 0 +FORCE_INLINE void softmax_fp32vec8(float* output, const float* input, int n) { + if (n <= 0) return; + + // ---------- Pass 1: find max ---------- + float max_val = -std::numeric_limits::infinity(); + int i = 0; + + for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) { + FP32Vec8 v(input + i); + FP32Vec8::AliasReg ar; + ar.reg = v.reg; + for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) { + if (ar.values[j] > max_val) max_val = ar.values[j]; + } + } + for (; i < n; ++i) { + if (input[i] > max_val) max_val = input[i]; + } + + // ---------- Pass 2: compute exp(x - max) and sum ---------- + float sum = 0.0f; + i = 0; + + for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) { + float tmp[FP32Vec8::VEC_ELEM_NUM]; + for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) { + tmp[j] = input[i + j] - max_val; + } + + FP32Vec8 v(tmp); + FP32Vec8 e = v.exp(); + + FP32Vec8::AliasReg ar; + ar.reg = e.reg; + for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) { + output[i + j] = ar.values[j]; + sum += ar.values[j]; + } + } + + // Tail + for (; i < n; ++i) { + float x = input[i] - max_val; + float ex = std::exp(x); // scalar tail + output[i] = ex; + sum += ex; + } + + // ---------- Pass 3: normalize ---------- + float inv_sum = 1.0f / sum; + i = 0; + + for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) { + float tmp[FP32Vec8::VEC_ELEM_NUM]; + for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) { + tmp[j] = output[i + j] * inv_sum; + } + FP32Vec8 v(tmp); + v.save(output + i); + } + + for (; i < n; ++i) { + output[i] *= inv_sum; + } +} + +// 1D RMSNorm kernel: +// input: x[0..n-1] +// weight: w[0..n-1] (gamma), may be nullptr +// output: y[i] = x[i] * inv_rms * (weight[i] if weight != nullptr else 1) +// eps: small epsilon for numerical stability +FORCE_INLINE void rmsnorm_fp32vec8(float* output, const float* input, + const float* weight, int n, float eps) { + if (n <= 0) return; + + // ---------- Pass 1: compute sum of 
squares ---------- + float sum_sq = 0.0f; + int i = 0; + + for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) { + FP32Vec8 x_vec(input + i); + + FP32Vec8 sq = x_vec * x_vec; + + FP32Vec8::AliasReg ar; + ar.reg = sq.reg; + for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) { + sum_sq += ar.values[j]; + } + } + + // Tail + for (; i < n; ++i) { + float v = input[i]; + sum_sq += v * v; + } + + float mean_sq = sum_sq / static_cast(n); + float inv_rms = 1.0f / std::sqrt(mean_sq + eps); + + // ---------- Pass 2: scale (and apply weight if given) ---------- + const float inv_rms_f = inv_rms; + i = 0; + + if (weight) { + // with gamma + for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) { + FP32Vec8 x_vec(input + i); + + float wtmp[FP32Vec8::VEC_ELEM_NUM]; + for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) { + wtmp[j] = weight[i + j]; + } + FP32Vec8 w_vec(wtmp); + + FP32Vec8 scale_vec(inv_rms_f); + FP32Vec8 y = x_vec * scale_vec * w_vec; + y.save(output + i); + } + + for (; i < n; ++i) { + output[i] = input[i] * inv_rms_f * weight[i]; + } + } else { + // without gamma + for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) { + FP32Vec8 x_vec(input + i); + FP32Vec8 scale_vec(inv_rms_f); + FP32Vec8 y = x_vec * scale_vec; + y.save(output + i); + } + + for (; i < n; ++i) { + output[i] = input[i] * inv_rms_f; + } + } +} + +// Prefetch data to cache for better memory access performance +FORCE_INLINE void prefetch(const void* addr) { + __builtin_prefetch(addr, 0, 3); // 0=read, 3=high temporal locality +} }; // namespace vec_op