diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index 6da4f6c0c..c9813a73d 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -359,6 +359,19 @@ else()
   add_compile_definitions(-DVLLM_NUMA_DISABLED)
 endif()
 
+#
+# Generate CPU attention dispatch header
+#
+message(STATUS "Generating CPU attention dispatch header")
+execute_process(
+  COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/csrc/cpu/generate_cpu_attn_dispatch.py
+  WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/csrc/cpu
+  RESULT_VARIABLE GEN_RESULT
+)
+if(NOT GEN_RESULT EQUAL 0)
+  message(FATAL_ERROR "Failed to generate CPU attention dispatch header")
+endif()
+
 #
 # _C extension
 #
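Note: the header is regenerated on every CMake configure via `execute_process`. For debugging outside the build, a minimal sketch (an assumption, not part of this change; it must run from the repository root, since the generator writes the header next to itself via `os.path.dirname(__file__)`):

```python
# Hypothetical helper: regenerate the dispatch header by hand and list the
# emitted case labels -- useful when checking why a (head_dim, ISA) pair
# falls through to the TORCH_CHECK default.
import pathlib
import subprocess
import sys

subprocess.run(
    [sys.executable, "csrc/cpu/generate_cpu_attn_dispatch.py"], check=True
)
header = pathlib.Path("csrc/cpu/cpu_attn_dispatch_generated.h").read_text()
for line in header.splitlines():
    if line.lstrip().startswith("case "):
        print(line)
```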
diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp
index 374fc2ee6..641f95a2b 100644
--- a/csrc/cpu/cpu_attn.cpp
+++ b/csrc/cpu/cpu_attn.cpp
@@ -1,79 +1,4 @@
-#include "cpu_attn_vec.hpp"
-#include "cpu_attn_vec16.hpp"
-
-#ifdef CPU_CAPABILITY_AMXBF16
-  #include "cpu_attn_amx.hpp"
-  #define AMX_DISPATCH(...)                                          \
-    case cpu_attention::ISA::AMX: {                                  \
-      using attn_impl = cpu_attention::AttentionImpl<                \
-          cpu_attention::ISA::AMX, scalar_t, head_dim>;              \
-      return __VA_ARGS__();                                          \
-    }
-#else
-  #define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
-#endif
-
-#ifdef __aarch64__
-  #include "cpu_attn_neon.hpp"
-  // NEON requires head_dim to be a multiple of 32
-  #define NEON_DISPATCH(...)                                         \
-    case cpu_attention::ISA::NEON: {                                 \
-      using attn_impl = cpu_attention::AttentionImpl<                \
-          cpu_attention::ISA::NEON, scalar_t, head_dim>;             \
-      return __VA_ARGS__();                                          \
-    }
-#else
-  #define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
-#endif  // #ifdef __aarch64__
-
-#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
-  case HEAD_DIM: {                            \
-    constexpr size_t head_dim = HEAD_DIM;     \
-    return __VA_ARGS__();                     \
-  }
-
-#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...)            \
-  [&] {                                                          \
-    switch (HEAD_DIM) {                                          \
-      CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__)                    \
-      CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__)                    \
-      CPU_ATTN_DISPATCH_CASE(80, __VA_ARGS__)                    \
-      CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__)                    \
-      CPU_ATTN_DISPATCH_CASE(112, __VA_ARGS__)                   \
-      CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__)                   \
-      CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__)                   \
-      CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__)                   \
-      CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__)                   \
-      CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__)                   \
-      default: {                                                 \
-        TORCH_CHECK(false, "Invalid CPU attention head_dim: " +  \
-                               std::to_string(HEAD_DIM));        \
-      }                                                          \
-    }                                                            \
-  }()
-
-#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...)                        \
-  [&] {                                                              \
-    switch (ISA_TYPE) {                                              \
-      AMX_DISPATCH(__VA_ARGS__)                                      \
-      NEON_DISPATCH(__VA_ARGS__)                                     \
-      case cpu_attention::ISA::VEC: {                                \
-        using attn_impl =                                            \
-            cpu_attention::AttentionImpl<cpu_attention::ISA::VEC,    \
-                                         scalar_t, head_dim>;        \
-        return __VA_ARGS__();                                        \
-      }                                                              \
-      case cpu_attention::ISA::VEC16: {                              \
-        using attn_impl =                                            \
-            cpu_attention::AttentionImpl<cpu_attention::ISA::VEC16,  \
-                                         scalar_t, head_dim>;        \
-        return __VA_ARGS__();                                        \
-      }                                                              \
-      default: {                                                     \
-        TORCH_CHECK(false, "Invalid CPU attention ISA type.");       \
-      }                                                              \
-    }                                                                \
-  }()
+#include "cpu_attn_dispatch_generated.h"
 
 torch::Tensor get_scheduler_metadata(
     const int64_t num_req, const int64_t num_heads_q,
@@ -122,16 +47,14 @@ torch::Tensor get_scheduler_metadata(
   input.enable_kv_split = enable_kv_split;
 
   VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
-    CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
-      CPU_ATTN_DISPATCH_IMPL(isa, [&]() {
-        input.elem_size = sizeof(scalar_t);
-        input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
-        input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
-        input.output_buffer_elem_size =
-            sizeof(attn_impl::partial_output_buffer_t);
-        input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
-        input.kv_block_alignment = attn_impl::BlockSizeAlignment;
-      });
+    CPU_ATTN_DISPATCH(head_dim, isa, [&]() {
+      input.elem_size = sizeof(scalar_t);
+      input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
+      input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
+      input.output_buffer_elem_size =
+          sizeof(attn_impl::partial_output_buffer_t);
+      input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
+      input.kv_block_alignment = attn_impl::BlockSizeAlignment;
     });
   });
 
@@ -184,18 +107,14 @@ void cpu_attn_reshape_and_cache(
 
   VLLM_DISPATCH_FLOATING_TYPES(
       key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
-        CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
-          CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() {
-            attn_impl::reshape_and_cache(
-                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                slot_mapping.data_ptr<int64_t>(), token_num,
-                key_token_num_stride, value_token_num_stride, head_num,
-                key_head_num_stride, value_head_num_stride, num_blocks,
-                num_blocks_stride, cache_head_num_stride, block_size,
-                block_size_stride);
-          });
+        CPU_ATTN_DISPATCH(head_dim, isa_tag, [&]() {
+          attn_impl::reshape_and_cache(
+              key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
+              key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
+              slot_mapping.data_ptr<int64_t>(), token_num, key_token_num_stride,
+              value_token_num_stride, head_num, key_head_num_stride,
+              value_head_num_stride, num_blocks, num_blocks_stride,
+              cache_head_num_stride, block_size, block_size_stride);
         });
       });
 }
@@ -257,12 +176,10 @@ void cpu_attention_with_kv_cache(
 
   VLLM_DISPATCH_FLOATING_TYPES(
       query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
-        CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] {
-          CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() {
-            TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
-            cpu_attention::AttentionMainLoop<attn_impl> mainloop;
-            mainloop(&input);
-          });
+        CPU_ATTN_DISPATCH(query.size(2), input.metadata->isa, [&]() {
+          TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
+          cpu_attention::AttentionMainLoop<attn_impl> mainloop;
+          mainloop(&input);
         });
       });
 }
diff --git a/csrc/cpu/cpu_attn_amx.hpp b/csrc/cpu/cpu_attn_amx.hpp
index 78be05e8d..8da458b99 100644
--- a/csrc/cpu/cpu_attn_amx.hpp
+++ b/csrc/cpu/cpu_attn_amx.hpp
@@ -377,7 +377,7 @@ class AttentionImpl {
                       const int32_t q_heads_per_kv, const int64_t q_num_stride,
                       const int64_t q_head_stride, const float scale) {
     constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
-    // static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
+    static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
     constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
     constexpr int64_t head_elem_num_pre_block =
         AMX_TILE_ROW_BYTES / sizeof(scalar_t);
diff --git a/csrc/cpu/cpu_attn_neon.hpp b/csrc/cpu/cpu_attn_neon.hpp
index e9ecd1d32..827f0cfbc 100644
--- a/csrc/cpu/cpu_attn_neon.hpp
+++ b/csrc/cpu/cpu_attn_neon.hpp
@@ -264,7 +264,7 @@ class AttentionImpl {
   constexpr static ISA ISAType = ISA::NEON;
   constexpr static bool scale_on_logits = false;  // apply scale on q_buffer
 
-  // static_assert(HeadDim % HeadDimAlignment == 0);
+  static_assert(HeadDim % HeadDimAlignment == 0);
   // the gemm micro kernel is Mx8
   static_assert(HeadDimAlignment % 8 == 0);
   static_assert(BlockSizeAlignment % 8 == 0);
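The two static_asserts above can be re-enabled because the generated dispatch (next file) only instantiates (head_dim, ISA) pairs that satisfy each ISA's alignment. A small Python sketch of that invariant; the concrete alignment values are assumptions for illustration (64-byte AMX tile rows, and the "multiple of 32" NEON requirement noted in the removed comment), not values read from the headers:

```python
# Sketch of the alignment invariant behind the re-enabled static_asserts.
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]  # all-ISA head dims
HEAD_DIMS_16 = [80, 112]                               # VEC16-only head dims

def amx_ok(head_dim: int, elem_size: int = 2) -> bool:  # bf16/fp16 elements
    # Assumed: AMX_TILE_ROW_BYTES == 64, per standard Intel AMX tiles.
    return (head_dim * elem_size) % 64 == 0

# Every all-ISA head_dim satisfies both the AMX and NEON constraints...
assert all(amx_ok(d) and d % 32 == 0 for d in HEAD_DIMS_32)
# ...while the VEC16-only dims are 16-aligned but not 32-aligned.
assert all(d % 16 == 0 and d % 32 != 0 for d in HEAD_DIMS_16)
```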
diff --git a/csrc/cpu/generate_cpu_attn_dispatch.py b/csrc/cpu/generate_cpu_attn_dispatch.py
new file mode 100644
index 000000000..85f21544d
--- /dev/null
+++ b/csrc/cpu/generate_cpu_attn_dispatch.py
@@ -0,0 +1,203 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Generate CPU attention dispatch switch cases and kernel instantiations.
+"""
+
+import os
+
+# Head dimensions divisible by 32 (support all ISAs)
+HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]
+
+# Head dimensions divisible by 16 but not 32 (VEC16 only)
+HEAD_DIMS_16 = [80, 112]
+
+# ISA types
+ISA_TYPES = {
+    "AMX": 0,
+    "VEC": 1,
+    "VEC16": 2,
+    "NEON": 3,
+}
+
+# ISAs supported for head_dims divisible by 32
+ISA_FOR_32 = ["AMX", "NEON", "VEC", "VEC16"]
+
+# ISAs supported for head_dims divisible by 16 only
+ISA_FOR_16 = ["VEC16"]
+
+
+def encode_params(head_dim: int, isa_type: str) -> int:
+    """Encode head_dim and ISA type into a single int64_t."""
+    isa_val = ISA_TYPES[isa_type]
+    # Encoding: (head_dim << 8) | isa_type
+    # This allows head_dim up to 2^55 - 1 (in a signed int64_t) and 256 ISA
+    # types
+    return (head_dim << 8) | isa_val
+
+
+def generate_cases_for_isa_group(isa_list: list[str]) -> str:
+    """Generate switch cases for a specific ISA group."""
+    cases = []
+
+    # Generate cases for head_dims divisible by 32
+    for head_dim in HEAD_DIMS_32:
+        for isa in isa_list:
+            if isa not in ISA_FOR_32:
+                continue
+            encoded = encode_params(head_dim, isa)
+            case_str = (
+                f"""    case {encoded}LL: {{ """
+                f"""/* head_dim={head_dim}, isa={isa} */ \\"""
+                f"""
+      constexpr size_t head_dim = {head_dim}; \\"""
+                f"""
+      using attn_impl = cpu_attention::AttentionImpl<"""
+                f"""cpu_attention::ISA::{isa}, \\"""
+                f"""
+          """
+                f"""scalar_t, head_dim>; \\"""
+                f"""
+      return __VA_ARGS__(); \\"""
+                f"""
+    }} \\"""
+            )
+            cases.append(case_str)
+
+    # Generate cases for head_dims divisible by 16 only
+    for head_dim in HEAD_DIMS_16:
+        for isa in isa_list:
+            encoded = encode_params(head_dim, isa)
+            case_str = (
+                f"""    case {encoded}LL: {{ """
+                f"""/* head_dim={head_dim}, isa={isa} """
+                f"""(using VEC16) */ \\"""
+                f"""
+      constexpr size_t head_dim = {head_dim}; \\"""
+                f"""
+      using attn_impl = cpu_attention::AttentionImpl<"""
+                f"""cpu_attention::ISA::VEC16, \\"""
+                f"""
+          """
+                f"""scalar_t, head_dim>; \\"""
+                f"""
+      return __VA_ARGS__(); \\"""
+                f"""
+    }} \\"""
+            )
+            cases.append(case_str)
+
+    return "\n".join(cases)
+
+
+def generate_helper_function() -> str:
+    """Generate helper function to encode parameters."""
+    return """
+inline int64_t encode_cpu_attn_params(int64_t head_dim, cpu_attention::ISA isa) {
+  return (head_dim << 8) | static_cast<int64_t>(isa);
+}
+"""
+
+
+def generate_header_file() -> str:
+    """Generate the complete header file content."""
+    header = """// auto generated by generate_cpu_attn_dispatch.py
+// clang-format off
+
+#ifndef CPU_ATTN_DISPATCH_GENERATED_H
+#define CPU_ATTN_DISPATCH_GENERATED_H
+
+#include "cpu_attn_vec.hpp"
+#include "cpu_attn_vec16.hpp"
+
+#ifdef CPU_CAPABILITY_AMXBF16
+  #include "cpu_attn_amx.hpp"
+#endif
+
+#ifdef __aarch64__
+  #include "cpu_attn_neon.hpp"
+#endif
+
+"""
+
+    header += generate_helper_function()
+
+    # Generate dispatch macro with conditional compilation for different
+    # ISA sets
+    header += """
+// Dispatch macro using encoded parameters
+"""
+
+    # x86_64 with AMX
+    header += """#if defined(CPU_CAPABILITY_AMXBF16)
+#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
+  [&] { \\
+    int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
+    switch (encoded_params) { \\
+"""
+    header += generate_cases_for_isa_group(["AMX", "VEC", "VEC16"])
+    header += """
+      default: { \\
+        TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
+                               std::to_string(HEAD_DIM) + " isa=" + \\
+                               std::to_string(static_cast<int>(ISA_TYPE))); \\
+      } \\
+    } \\
+  }()
+
+"""
+
+    # ARM64 with NEON
+    header += """#elif defined(__aarch64__)
+#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
+  [&] { \\
+    int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
+    switch (encoded_params) { \\
+"""
+    header += generate_cases_for_isa_group(["NEON", "VEC", "VEC16"])
+    header += """
+      default: { \\
+        TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
+                               std::to_string(HEAD_DIM) + " isa=" + \\
+                               std::to_string(static_cast<int>(ISA_TYPE))); \\
+      } \\
+    } \\
+  }()
+
+"""
+
+    # Fallback: VEC and VEC16 only
+    header += """#else
+#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
+  [&] { \\
+    int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
+    switch (encoded_params) { \\
+"""
+    header += generate_cases_for_isa_group(["VEC", "VEC16"])
+    header += """
+      default: { \\
+        TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
+                               std::to_string(HEAD_DIM) + " isa=" + \\
+                               std::to_string(static_cast<int>(ISA_TYPE))); \\
+      } \\
+    } \\
+  }()
+
+#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ */
+
+#endif  // CPU_ATTN_DISPATCH_GENERATED_H
+"""
+
+    return header
+
+
+def main():
+    output_path = os.path.join(
+        os.path.dirname(__file__), "cpu_attn_dispatch_generated.h"
+    )
+
+    with open(output_path, "w") as f:
+        f.write(generate_header_file())
+
+
+if __name__ == "__main__":
+    main()
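For reference, the case labels the generator emits follow directly from the `(head_dim << 8) | isa` encoding. A quick round-trip check in Python (the `decode` helper is hypothetical, added here only to mirror `encode_params()`/`encode_cpu_attn_params()` above):

```python
ISA_TYPES = {"AMX": 0, "VEC": 1, "VEC16": 2, "NEON": 3}

def encode(head_dim: int, isa: str) -> int:
    # Low 8 bits carry the ISA, the remaining bits carry head_dim.
    return (head_dim << 8) | ISA_TYPES[isa]

def decode(encoded: int) -> tuple[int, int]:
    return encoded >> 8, encoded & 0xFF

assert encode(128, "AMX") == 32768    # emitted as "case 32768LL:"
assert encode(80, "VEC16") == 20482   # 16-aligned dims dispatch to VEC16
assert decode(encode(96, "NEON")) == (96, 3)
```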
diff --git a/tests/kernels/attention/test_cpu_attn.py b/tests/kernels/attention/test_cpu_attn.py
index ef0099f63..9636dfb95 100644
--- a/tests/kernels/attention/test_cpu_attn.py
+++ b/tests/kernels/attention/test_cpu_attn.py
@@ -26,6 +26,7 @@ NUM_HEADS = [
     (9, 3),
 ]
 HEAD_SIZES = [96, 128]
+HEAD_SIZES_VEC16 = [96, 80, 112, 128]
 QTYPES = [torch.bfloat16, torch.half, torch.float32]
 SLIDING_WINDOWS = [None, 256]
 NUM_BLOCKS = [
@@ -432,7 +433,7 @@ def test_varlen_with_paged_kv_normal_amx(
 
 @pytest.mark.parametrize("seq_lens", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("head_size", HEAD_SIZES_VEC16)
 @pytest.mark.parametrize("block_size", [48])
 @pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
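The widened `HEAD_SIZES_VEC16` list adds the 16-but-not-32-aligned head sizes to the parametrization. A small sketch partitioning the list by which dispatch path each size can take (the divisibility rules are taken from the generator's `HEAD_DIMS_32`/`HEAD_DIMS_16` split above):

```python
HEAD_SIZES_VEC16 = [96, 80, 112, 128]
vec16_only = [d for d in HEAD_SIZES_VEC16 if d % 32 != 0]  # [80, 112]
all_isa = [d for d in HEAD_SIZES_VEC16 if d % 32 == 0]     # [96, 128]
print(f"VEC16-only: {vec16_only}, all ISAs: {all_isa}")
```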