diff --git a/csrc/cpu/mla_decode.cpp b/csrc/cpu/mla_decode.cpp
index bc0ac5bc5..bd489b463 100644
--- a/csrc/cpu/mla_decode.cpp
+++ b/csrc/cpu/mla_decode.cpp
@@ -38,6 +38,15 @@ struct KernelVecType {
   using qk_vec_type = vec_op::BF16Vec32;
   using v_load_vec_type = vec_op::BF16Vec16;
 };
+
+#elif defined(__s390x__)  // Branch compiled only when the toolchain defines __s390x__.
+template <>  // NOTE(review): the specialization's template argument list is not visible in this view - confirm it matches the sibling branches.
+struct KernelVecType {  // Vector-type aliases used for this target.
+  using qk_load_vec_type = vec_op::BF16Vec16;  // Presumably the type used to load q/k data - confirm against users of KernelVecType.
+  using qk_vec_type = vec_op::FP32Vec16;  // Differs from the preceding branch, which aliases vec_op::BF16Vec32 here.
+  using v_load_vec_type = vec_op::BF16Vec16;  // Same alias as the preceding branch.
+};
+
 #elif defined(__aarch64__)
 template <>
 struct KernelVecType {