2026-02-04 09:07:15 +05:30
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
"""
|
|
|
|
|
Generate CPU attention dispatch switch cases and kernel instantiations.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
# Head dimensions that are multiples of 32; a case is generated for every
# requested ISA that appears in ISA_FOR_32.
HEAD_DIMS_32 = list(range(32, 257, 32))

# Head dimensions that are multiples of 16 but not of 32; their cases are
# always instantiated with the VEC16 implementation.
HEAD_DIMS_16 = [80, 112]

# Numeric id of each ISA. The id fills the low 8 bits of the encoded dispatch
# key, and the generated C++ helper builds the same key from
# static_cast<int64_t>(isa), so each id must equal the integer value of the
# matching cpu_attention::ISA enumerator.
ISA_TYPES = {
    "AMX": 0,
    "VEC": 1,
    "VEC16": 2,
    "NEON": 3,
    "VXE": 4,
}

# ISAs that may be instantiated for the head dimensions in HEAD_DIMS_32.
ISA_FOR_32 = ["AMX", "NEON", "VEC", "VEC16", "VXE"]

# ISAs that may be instantiated for the head dimensions in HEAD_DIMS_16.
ISA_FOR_16 = ["VEC16"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def encode_params(head_dim: int, isa_type: str) -> int:
    """Pack a head dimension and an ISA name into a single dispatch key.

    The key is ``(head_dim << 8) | ISA_TYPES[isa_type]``. It is emitted as an
    ``int64_t`` case label, and the generated C++ helper
    ``encode_cpu_attn_params`` computes the same layout at runtime.

    Args:
        head_dim: Attention head dimension. Must be in ``[0, 2**55)`` so that
            the shifted value still fits in a signed 64-bit integer.
        isa_type: A key of ``ISA_TYPES`` (for example ``"VEC16"``).

    Returns:
        The encoded key.

    Raises:
        KeyError: If ``isa_type`` is not a key of ``ISA_TYPES``.
        ValueError: If the ISA id does not fit in 8 bits or ``head_dim`` is
            outside the representable range.
    """
    isa_val = ISA_TYPES[isa_type]
    # The low byte holds the ISA id; a wider id would spill into the head_dim
    # bits and silently produce the same label as a different configuration.
    if not 0 <= isa_val < (1 << 8):
        raise ValueError(
            f"ISA id {isa_val} for {isa_type!r} does not fit in 8 bits"
        )
    # int64_t is signed, so 63 - 8 = 55 bits remain for head_dim.
    if not 0 <= head_dim < (1 << 55):
        raise ValueError(f"head_dim {head_dim} is outside [0, 2**55)")
    # Encoding: (head_dim << 8) | isa_type
    return (head_dim << 8) | isa_val
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _render_dispatch_case(
    head_dim: int, requested_isa: str, impl_isa: str, note: str = ""
) -> str:
    """Render one ``case`` of the dispatch switch as macro-continued lines.

    The case label encodes ``(head_dim, requested_isa)``; the body
    instantiates ``AttentionImpl`` with ``impl_isa``. ``note`` is appended to
    the descriptive comment on the first line.
    """
    encoded = encode_params(head_dim, requested_isa)
    using_prefix = "      using attn_impl = cpu_attention::AttentionImpl<"
    # Align the wrapped template arguments under the first one.
    arg_indent = " " * len(using_prefix)
    lines = [
        f"    case {encoded}LL: {{ "
        f"/* head_dim={head_dim}, isa={requested_isa}{note} */ \\",
        f"      constexpr size_t head_dim = {head_dim}; \\",
        f"{using_prefix}cpu_attention::ISA::{impl_isa}, \\",
        f"{arg_indent}scalar_t, head_dim>; \\",
        "      return __VA_ARGS__(); \\",
        "    } \\",
    ]
    return "\n".join(lines)


def generate_cases_for_isa_group(isa_list: list[str]) -> str:
    """Generate the switch cases for one group of ISAs.

    Every line of the result ends with a line-continuation backslash because
    the cases are spliced into the body of the ``CPU_ATTN_DISPATCH`` macro.

    Args:
        isa_list: ISA names (keys of ``ISA_TYPES``) available in the target
            build configuration.

    Returns:
        The cases joined by newlines, without a trailing newline. Cases for
        ``HEAD_DIMS_32`` come first, followed by cases for ``HEAD_DIMS_16``.
    """
    cases = []

    # Head dims divisible by 32: each requested ISA gets its own kernel,
    # provided it is listed in ISA_FOR_32.
    for head_dim in HEAD_DIMS_32:
        for isa in isa_list:
            if isa in ISA_FOR_32:
                cases.append(_render_dispatch_case(head_dim, isa, isa))

    # Head dims divisible by 16 only: the label still encodes the requested
    # ISA, but the kernel is always the VEC16 implementation.
    for head_dim in HEAD_DIMS_16:
        for isa in isa_list:
            cases.append(
                _render_dispatch_case(head_dim, isa, "VEC16", " (using VEC16)")
            )

    return "\n".join(cases)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_helper_function() -> str:
    """Return the C++ source of the runtime key-encoding helper.

    The emitted ``encode_cpu_attn_params`` function packs ``head_dim`` and the
    ISA enumerator as ``(head_dim << 8) | isa``. The text starts with an empty
    line and ends with a newline.
    """
    signature = (
        "inline int64_t encode_cpu_attn_params("
        "int64_t head_dim, cpu_attention::ISA isa) {"
    )
    cpp_lines = (
        "",
        signature,
        "    return (head_dim << 8) | static_cast<int64_t>(isa);",
        "}",
        "",
    )
    return "\n".join(cpp_lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_header_file() -> str:
    """Generate the complete header file content.

    The header consists of, in order: the include guard and implementation
    includes, the ``encode_cpu_attn_params`` helper, and one definition of the
    ``CPU_ATTN_DISPATCH`` macro per build configuration, selected by an
    ``#if`` / ``#elif`` / ``#else`` chain. Each macro expands to an
    immediately-invoked lambda that switches on the encoded
    ``(head_dim, isa)`` key and fails with ``TORCH_CHECK`` for unknown keys.

    Returns:
        The full text of the generated C++ header.
    """
    preamble = """// auto generated by generate_cpu_attn_dispatch.py
// clang-format off

#ifndef CPU_ATTN_DISPATCH_GENERATED_H
#define CPU_ATTN_DISPATCH_GENERATED_H

#include "cpu_attn_vec.hpp"
#include "cpu_attn_vec16.hpp"

#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu_attn_amx.hpp"
#endif

#ifdef __aarch64__
#include "cpu_attn_neon.hpp"
#endif

#ifdef __s390x__
#include "cpu_attn_vxe.hpp"
#endif

"""

    # Preprocessor guard and the ISAs dispatched under it, in chain order:
    # x86_64 with AMX, ARM64 with NEON, s390x with VXE, then the fallback.
    variants = [
        ("#if defined(CPU_CAPABILITY_AMXBF16)", ["AMX", "VEC", "VEC16"]),
        ("#elif defined(__aarch64__)", ["NEON", "VEC", "VEC16"]),
        ("#elif defined(__s390x__)", ["VXE", "VEC", "VEC16"]),
        ("#else", ["VEC", "VEC16"]),
    ]

    # Shared by every variant. Each macro line must end in a backslash with
    # nothing after it, otherwise the macro definition is cut short.
    macro_open = """
#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
  [&] { \\
    int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
    switch (encoded_params) { \\
"""
    macro_close = """
    default: { \\
      TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
                  std::to_string(HEAD_DIM) + " isa=" + \\
                  std::to_string(static_cast<int>(ISA_TYPE))); \\
    } \\
    } \\
  }()

"""

    parts = [
        preamble,
        generate_helper_function(),
        "\n// Dispatch macro using encoded parameters\n",
    ]
    for guard, isa_group in variants:
        parts.append(guard)
        parts.append(macro_open)
        parts.append(generate_cases_for_isa_group(isa_group))
        parts.append(macro_close)

    parts.append(
        "#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ / __s390x__ */\n"
        "\n"
        "#endif // CPU_ATTN_DISPATCH_GENERATED_H\n"
    )
    return "".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Write the generated dispatch header next to this script.

    The output file is ``cpu_attn_dispatch_generated.h`` in the directory that
    contains this script; an existing file is overwritten.
    """
    output_path = os.path.join(
        os.path.dirname(__file__), "cpu_attn_dispatch_generated.h"
    )
    # Pin the encoding so the bytes written do not depend on the platform's
    # locale default.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(generate_header_file())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Regenerate the header only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|