#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Generate CPU attention dispatch switch cases and kernel instantiations.

Writes ``cpu_attn_dispatch_generated.h`` next to this script.  The header
defines a ``CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...)`` macro that switches
on an encoded (head_dim, ISA) pair and instantiates the matching
``cpu_attention::AttentionImpl`` specialization.
"""

import os

# Head dimensions divisible by 32 (support all ISAs)
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]
# Head dimensions divisible by 16 but not 32 (VEC16 only)
HEAD_DIMS_16 = [80, 112]

# ISA types and their numeric encoding (must match the C++ enum ordering).
ISA_TYPES = {
    "AMX": 0,
    "VEC": 1,
    "VEC16": 2,
    "NEON": 3,
    "VXE": 4,
}

# ISAs supported for head_dims divisible by 32
ISA_FOR_32 = ["AMX", "NEON", "VEC", "VEC16", "VXE"]
# ISAs supported for head_dims divisible by 16 only
ISA_FOR_16 = ["VEC16"]


def encode_params(head_dim: int, isa_type: str) -> int:
    """Encode head_dim and ISA type into a single int64_t.

    Args:
        head_dim: attention head dimension.
        isa_type: key into ``ISA_TYPES``.

    Returns:
        ``(head_dim << 8) | isa_value`` — mirrors the C++ helper
        ``encode_cpu_attn_params`` so Python-side case labels and runtime
        encodings agree.

    Raises:
        KeyError: if ``isa_type`` is not a known ISA name.
    """
    isa_val = ISA_TYPES[isa_type]
    # Encoding: (head_dim << 8) | isa_type
    # This allows head_dim up to 2^56 - 1 and 256 ISA types
    return (head_dim << 8) | isa_val


def generate_cases_for_isa_group(isa_list: list[str]) -> str:
    """Generate switch cases for a specific ISA group.

    For head dims divisible by 32 each requested ISA gets its own
    specialization; for head dims divisible by 16 only, every requested ISA
    keeps its own case label but falls back to the VEC16 implementation.
    """
    cases = []

    # Head dims divisible by 32: instantiate the requested ISA directly.
    for head_dim in HEAD_DIMS_32:
        for isa in isa_list:
            if isa not in ISA_FOR_32:
                continue
            encoded = encode_params(head_dim, isa)
            cases.append(
                f"""      case {encoded}LL: {{ /* head_dim={head_dim}, isa={isa} */ \\
        constexpr size_t head_dim = {head_dim}; \\
        using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::{isa}, \\
                                                       scalar_t, head_dim>; \\
        return __VA_ARGS__(); \\
      }} \\"""
            )

    # Head dims divisible by 16 only: the case label still encodes the
    # requested ISA, but the implementation always uses VEC16.
    for head_dim in HEAD_DIMS_16:
        for isa in isa_list:
            encoded = encode_params(head_dim, isa)
            cases.append(
                f"""      case {encoded}LL: {{ /* head_dim={head_dim}, isa={isa} (using VEC16) */ \\
        constexpr size_t head_dim = {head_dim}; \\
        using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::VEC16, \\
                                                       scalar_t, head_dim>; \\
        return __VA_ARGS__(); \\
      }} \\"""
            )

    return "\n".join(cases)


def generate_helper_function() -> str:
    """Generate the C++ helper that encodes (head_dim, ISA) at runtime."""
    # NOTE: static_cast requires an explicit target type in C++;
    # <int64_t> matches the Python-side encoding in encode_params().
    return """
inline int64_t encode_cpu_attn_params(int64_t head_dim, cpu_attention::ISA isa) {
  return (head_dim << 8) | static_cast<int64_t>(isa);
}
"""


def _dispatch_macro(isa_list: list[str]) -> str:
    """Return one full CPU_ATTN_DISPATCH macro definition for an ISA group.

    Shared by all preprocessor branches so the macro prologue and the
    unsupported-configuration default case are defined in exactly one place.
    """
    return (
        """#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
  [&] { \\
    int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
    switch (encoded_params) { \\
"""
        + generate_cases_for_isa_group(isa_list)
        + """
      default: { \\
        TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
                    std::to_string(HEAD_DIM) + " isa=" + \\
                    std::to_string(static_cast<int>(ISA_TYPE))); \\
      } \\
    } \\
  }()
"""
    )


def generate_header_file() -> str:
    """Generate the complete header file content.

    Emits include guards, capability-conditional includes, the runtime
    encoding helper, and one CPU_ATTN_DISPATCH definition per platform:
    x86-64 with AMX, aarch64 (NEON), s390x (VXE), and a vector-only fallback.
    """
    header = """// auto generated by generate_cpu_attn_dispatch.py
// clang-format off
#ifndef CPU_ATTN_DISPATCH_GENERATED_H
#define CPU_ATTN_DISPATCH_GENERATED_H

#include "cpu_attn_vec.hpp"
#include "cpu_attn_vec16.hpp"

#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu_attn_amx.hpp"
#endif

#ifdef __aarch64__
#include "cpu_attn_neon.hpp"
#endif

#ifdef __s390x__
#include "cpu_attn_vxe.hpp"
#endif
"""
    header += generate_helper_function()

    header += """
// Dispatch macro using encoded parameters
"""

    # x86_64 with AMX
    header += "#if defined(CPU_CAPABILITY_AMXBF16)\n"
    header += _dispatch_macro(["AMX", "VEC", "VEC16"])

    # ARM64 with NEON
    header += "#elif defined(__aarch64__)\n"
    header += _dispatch_macro(["NEON", "VEC", "VEC16"])

    # s390x with VXE
    header += "#elif defined(__s390x__)\n"
    header += _dispatch_macro(["VXE", "VEC", "VEC16"])

    # Fallback: VEC and VEC16 only
    header += "#else\n"
    header += _dispatch_macro(["VEC", "VEC16"])
    header += """#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ / __s390x__ */

#endif // CPU_ATTN_DISPATCH_GENERATED_H
"""
    return header


def main():
    """Write the generated header next to this script."""
    output_path = os.path.join(
        os.path.dirname(__file__), "cpu_attn_dispatch_generated.h"
    )
    with open(output_path, "w") as f:
        f.write(generate_header_file())


if __name__ == "__main__":
    main()