[CPU] Split attention dispatch by head_dim alignment (#32161)
Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
This commit is contained in:
@@ -359,6 +359,19 @@ else()
|
|||||||
add_compile_definitions(-DVLLM_NUMA_DISABLED)
|
add_compile_definitions(-DVLLM_NUMA_DISABLED)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
#
# Generate CPU attention dispatch header
#
# The generator is executed at configure time (execute_process), so the header
# exists before any target that includes it is compiled.
set(CPU_ATTN_DISPATCH_GENERATOR
    "${CMAKE_SOURCE_DIR}/csrc/cpu/generate_cpu_attn_dispatch.py")

# Re-run the configure step (and with it the generator) whenever the script
# itself is edited; without this a stale header would survive until the next
# manual re-configure.
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
             "${CPU_ATTN_DISPATCH_GENERATOR}")

message(STATUS "Generating CPU attention dispatch header")
execute_process(
    COMMAND ${Python_EXECUTABLE} ${CPU_ATTN_DISPATCH_GENERATOR}
    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/csrc/cpu
    RESULT_VARIABLE GEN_RESULT
)
if(NOT GEN_RESULT EQUAL 0)
    # RESULT_VARIABLE holds either the exit code or an error string; report it
    # so a failing build log shows why generation stopped.
    message(FATAL_ERROR
        "Failed to generate CPU attention dispatch header (result: ${GEN_RESULT})")
endif()
|
||||||
|
|
||||||
#
|
#
|
||||||
# _C extension
|
# _C extension
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -1,79 +1,4 @@
|
|||||||
#include "cpu_attn_vec.hpp"
|
#include "cpu_attn_dispatch_generated.h"
|
||||||
#include "cpu_attn_vec16.hpp"
|
|
||||||
|
|
||||||
#ifdef CPU_CAPABILITY_AMXBF16
|
|
||||||
#include "cpu_attn_amx.hpp"
|
|
||||||
#define AMX_DISPATCH(...) \
|
|
||||||
case cpu_attention::ISA::AMX: { \
|
|
||||||
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::AMX, \
|
|
||||||
scalar_t, head_dim>; \
|
|
||||||
return __VA_ARGS__(); \
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef __aarch64__
|
|
||||||
#include "cpu_attn_neon.hpp"
|
|
||||||
// NEON requires head_dim to be a multiple of 32
|
|
||||||
#define NEON_DISPATCH(...) \
|
|
||||||
case cpu_attention::ISA::NEON: { \
|
|
||||||
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
|
|
||||||
scalar_t, head_dim>; \
|
|
||||||
return __VA_ARGS__(); \
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
#define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
|
|
||||||
#endif // #ifdef __aarch64__
|
|
||||||
|
|
||||||
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
|
|
||||||
case HEAD_DIM: { \
|
|
||||||
constexpr size_t head_dim = HEAD_DIM; \
|
|
||||||
return __VA_ARGS__(); \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \
|
|
||||||
[&] { \
|
|
||||||
switch (HEAD_DIM) { \
|
|
||||||
CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
|
|
||||||
CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
|
|
||||||
CPU_ATTN_DISPATCH_CASE(80, __VA_ARGS__) \
|
|
||||||
CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
|
|
||||||
CPU_ATTN_DISPATCH_CASE(112, __VA_ARGS__) \
|
|
||||||
CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
|
|
||||||
CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
|
|
||||||
CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
|
|
||||||
CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \
|
|
||||||
CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \
|
|
||||||
default: { \
|
|
||||||
TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \
|
|
||||||
std::to_string(HEAD_DIM)); \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
}()
|
|
||||||
|
|
||||||
#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \
|
|
||||||
[&] { \
|
|
||||||
switch (ISA_TYPE) { \
|
|
||||||
AMX_DISPATCH(__VA_ARGS__) \
|
|
||||||
NEON_DISPATCH(__VA_ARGS__) \
|
|
||||||
case cpu_attention::ISA::VEC: { \
|
|
||||||
using attn_impl = \
|
|
||||||
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
|
|
||||||
head_dim>; \
|
|
||||||
return __VA_ARGS__(); \
|
|
||||||
} \
|
|
||||||
case cpu_attention::ISA::VEC16: { \
|
|
||||||
using attn_impl = \
|
|
||||||
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC16, scalar_t, \
|
|
||||||
head_dim>; \
|
|
||||||
return __VA_ARGS__(); \
|
|
||||||
} \
|
|
||||||
default: { \
|
|
||||||
TORCH_CHECK(false, "Invalid CPU attention ISA type."); \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
}()
|
|
||||||
|
|
||||||
torch::Tensor get_scheduler_metadata(
|
torch::Tensor get_scheduler_metadata(
|
||||||
const int64_t num_req, const int64_t num_heads_q,
|
const int64_t num_req, const int64_t num_heads_q,
|
||||||
@@ -122,16 +47,14 @@ torch::Tensor get_scheduler_metadata(
|
|||||||
input.enable_kv_split = enable_kv_split;
|
input.enable_kv_split = enable_kv_split;
|
||||||
|
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
|
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
|
||||||
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
|
CPU_ATTN_DISPATCH(head_dim, isa, [&]() {
|
||||||
CPU_ATTN_DISPATCH_IMPL(isa, [&]() {
|
input.elem_size = sizeof(scalar_t);
|
||||||
input.elem_size = sizeof(scalar_t);
|
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
|
||||||
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
|
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
|
||||||
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
|
input.output_buffer_elem_size =
|
||||||
input.output_buffer_elem_size =
|
sizeof(attn_impl::partial_output_buffer_t);
|
||||||
sizeof(attn_impl::partial_output_buffer_t);
|
input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
|
||||||
input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
|
input.kv_block_alignment = attn_impl::BlockSizeAlignment;
|
||||||
input.kv_block_alignment = attn_impl::BlockSizeAlignment;
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -184,18 +107,14 @@ void cpu_attn_reshape_and_cache(
|
|||||||
|
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
|
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
|
||||||
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
|
CPU_ATTN_DISPATCH(head_dim, isa_tag, [&]() {
|
||||||
CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() {
|
attn_impl::reshape_and_cache(
|
||||||
attn_impl::reshape_and_cache(
|
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
|
||||||
key_cache.data_ptr<scalar_t>(),
|
slot_mapping.data_ptr<int64_t>(), token_num, key_token_num_stride,
|
||||||
value_cache.data_ptr<scalar_t>(),
|
value_token_num_stride, head_num, key_head_num_stride,
|
||||||
slot_mapping.data_ptr<int64_t>(), token_num,
|
value_head_num_stride, num_blocks, num_blocks_stride,
|
||||||
key_token_num_stride, value_token_num_stride, head_num,
|
cache_head_num_stride, block_size, block_size_stride);
|
||||||
key_head_num_stride, value_head_num_stride, num_blocks,
|
|
||||||
num_blocks_stride, cache_head_num_stride, block_size,
|
|
||||||
block_size_stride);
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -257,12 +176,10 @@ void cpu_attention_with_kv_cache(
|
|||||||
|
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
|
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
|
||||||
CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] {
|
CPU_ATTN_DISPATCH(query.size(2), input.metadata->isa, [&]() {
|
||||||
CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() {
|
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
|
||||||
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
|
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
|
||||||
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
|
mainloop(&input);
|
||||||
mainloop(&input);
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -377,7 +377,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
|
|||||||
const int32_t q_heads_per_kv, const int64_t q_num_stride,
|
const int32_t q_heads_per_kv, const int64_t q_num_stride,
|
||||||
const int64_t q_head_stride, const float scale) {
|
const int64_t q_head_stride, const float scale) {
|
||||||
constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
|
constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
|
||||||
// static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
|
static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
|
||||||
constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
|
constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
|
||||||
constexpr int64_t head_elem_num_pre_block =
|
constexpr int64_t head_elem_num_pre_block =
|
||||||
AMX_TILE_ROW_BYTES / sizeof(scalar_t);
|
AMX_TILE_ROW_BYTES / sizeof(scalar_t);
|
||||||
|
|||||||
@@ -264,7 +264,7 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
|
|||||||
constexpr static ISA ISAType = ISA::NEON;
|
constexpr static ISA ISAType = ISA::NEON;
|
||||||
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
|
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
|
||||||
|
|
||||||
// static_assert(HeadDim % HeadDimAlignment == 0);
|
static_assert(HeadDim % HeadDimAlignment == 0);
|
||||||
// the gemm micro kernel is Mx8
|
// the gemm micro kernel is Mx8
|
||||||
static_assert(HeadDimAlignment % 8 == 0);
|
static_assert(HeadDimAlignment % 8 == 0);
|
||||||
static_assert(BlockSizeAlignment % 8 == 0);
|
static_assert(BlockSizeAlignment % 8 == 0);
|
||||||
|
|||||||
203
csrc/cpu/generate_cpu_attn_dispatch.py
Normal file
203
csrc/cpu/generate_cpu_attn_dispatch.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Generate CPU attention dispatch switch cases and kernel instantiations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Head sizes that are multiples of 32; every ISA can serve these.
HEAD_DIMS_32 = list(range(32, 257, 32))

# Head sizes that are multiples of 16 but not of 32; only VEC16 serves these.
HEAD_DIMS_16 = [80, 112]

# Numeric tag of each ISA, OR-ed into the low byte of the switch key.
# NOTE(review): the generated C++ helper computes the same key with
# static_cast<int64_t>(isa), so these numbers presumably have to equal the
# cpu_attention::ISA enumerator values -- the enum is not visible here, confirm.
ISA_TYPES = dict(AMX=0, VEC=1, VEC16=2, NEON=3)

# ISAs usable with the 32-aligned head sizes.
ISA_FOR_32 = ["AMX", "NEON", "VEC", "VEC16"]

# ISAs usable with the 16-aligned-only head sizes.
ISA_FOR_16 = ["VEC16"]
|
||||||
|
|
||||||
|
|
||||||
|
def encode_params(head_dim: int, isa_type: str) -> int:
    """Pack a head dimension and an ISA name into a single switch key.

    The key is ``(head_dim << 8) | ISA_TYPES[isa_type]``: the ISA tag sits in
    the low 8 bits and the head dimension in the bits above it.  The key is
    emitted as an ``LL`` literal and compared against an ``int64_t`` in the
    generated C++, so it has to stay inside the signed 64-bit range.

    Args:
        head_dim: Attention head size; must satisfy ``0 <= head_dim < 2**55``.
        isa_type: A key of ``ISA_TYPES``.

    Returns:
        The encoded, non-negative key.

    Raises:
        KeyError: ``isa_type`` is not a key of ``ISA_TYPES``.
        ValueError: the ISA tag does not fit in 8 bits, or ``head_dim`` is
            negative or too large for a signed 64-bit key.
    """
    isa_val = ISA_TYPES[isa_type]

    # 8 bits are reserved for the ISA tag; a larger value would bleed into the
    # head_dim field and alias another (head_dim, isa) pair.
    if not 0 <= isa_val < (1 << 8):
        raise ValueError(f"ISA tag {isa_val} for {isa_type!r} does not fit in 8 bits")

    # 63 value bits of int64_t minus the 8 tag bits leave 55 bits for head_dim.
    if not 0 <= head_dim < (1 << 55):
        raise ValueError(f"head_dim {head_dim} is outside [0, 2**55)")

    return (head_dim << 8) | isa_val
|
||||||
|
|
||||||
|
|
||||||
|
def _format_dispatch_case(
    head_dim: int, requested_isa: str, impl_isa: str, note: str = ""
) -> str:
    """Render one ``case`` arm of the CPU_ATTN_DISPATCH switch.

    The label encodes ``(head_dim, requested_isa)``; the arm itself binds
    ``attn_impl`` to the ``impl_isa`` specialisation.  Every emitted line ends
    with a backslash because the arm is spliced into a multi-line macro.
    """
    encoded = encode_params(head_dim, requested_isa)
    return (
        f"      case {encoded}LL: {{ "
        f"/* head_dim={head_dim}, isa={requested_isa}{note} */ \\\n"
        f"        constexpr size_t head_dim = {head_dim}; \\\n"
        f"        using attn_impl = cpu_attention::AttentionImpl<"
        f"cpu_attention::ISA::{impl_isa}, \\\n"
        f"                                                       "
        f"scalar_t, head_dim>; \\\n"
        f"        return __VA_ARGS__(); \\\n"
        f"      }} \\"
    )


def generate_cases_for_isa_group(isa_list: list[str]) -> str:
    """Generate the switch cases for one ISA group.

    Args:
        isa_list: ISA names available in the target build configuration.

    Returns:
        The case arms joined by newlines (no trailing newline).  For each size
        in ``HEAD_DIMS_32`` there is one arm per ISA of ``isa_list`` that is
        also in ``ISA_FOR_32``, dispatching to that ISA.  For each size in
        ``HEAD_DIMS_16`` there is one arm per ISA of ``isa_list``, and all of
        them dispatch to the VEC16 implementation regardless of the ISA that
        was requested.
    """
    # 32-aligned head sizes: the requested ISA is used as-is.
    cases = [
        _format_dispatch_case(head_dim, isa, isa)
        for head_dim in HEAD_DIMS_32
        for isa in isa_list
        if isa in ISA_FOR_32
    ]

    # 16-aligned-only head sizes: keep a label for every requested ISA so the
    # switch still matches, but route all of them to VEC16.
    cases += [
        _format_dispatch_case(head_dim, isa, "VEC16", " (using VEC16)")
        for head_dim in HEAD_DIMS_16
        for isa in isa_list
    ]

    return "\n".join(cases)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_helper_function() -> str:
    """Return the C++ source of the ``encode_cpu_attn_params`` inline helper.

    The emitted function packs ``head_dim`` and the ISA enumerator into one
    ``int64_t`` as ``(head_dim << 8) | isa``, the same layout used for the
    generated ``case`` labels.
    """
    source_lines = (
        "",
        "inline int64_t encode_cpu_attn_params(int64_t head_dim, cpu_attention::ISA isa) {",
        "    return (head_dim << 8) | static_cast<int64_t>(isa);",
        "}",
        "",
    )
    return "\n".join(source_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_dispatch_macro(isa_list: list[str]) -> str:
    """Render one ``#define CPU_ATTN_DISPATCH(...)`` for the given ISA group.

    The macro expands to an immediately-invoked lambda that switches on the
    encoded (head_dim, isa) key and fails via TORCH_CHECK for unknown keys.
    Every line but the last ends with a backslash so the definition stays a
    single logical preprocessor line.
    """
    return (
        "#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\\n"
        "  [&] { \\\n"
        "    int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\\n"
        "    switch (encoded_params) { \\\n"
        + generate_cases_for_isa_group(isa_list)
        + "\n"
        "      default: { \\\n"
        '        TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\\n'
        '                    std::to_string(HEAD_DIM) + " isa=" + \\\n'
        "                    std::to_string(static_cast<int>(ISA_TYPE))); \\\n"
        "      } \\\n"
        "    } \\\n"
        "  }()\n"
        "\n"
    )


def generate_header_file() -> str:
    """Generate the complete header file content.

    Layout: include guard and implementation includes, the
    ``encode_cpu_attn_params`` helper, then one ``CPU_ATTN_DISPATCH`` macro per
    build flavour selected by the preprocessor:

    * ``CPU_CAPABILITY_AMXBF16`` defined -> AMX, VEC, VEC16
    * ``__aarch64__`` defined            -> NEON, VEC, VEC16
    * otherwise                          -> VEC, VEC16
    """
    preamble = """// auto generated by generate_cpu_attn_dispatch.py
// clang-format off

#ifndef CPU_ATTN_DISPATCH_GENERATED_H
#define CPU_ATTN_DISPATCH_GENERATED_H

#include "cpu_attn_vec.hpp"
#include "cpu_attn_vec16.hpp"

#ifdef CPU_CAPABILITY_AMXBF16
  #include "cpu_attn_amx.hpp"
#endif

#ifdef __aarch64__
  #include "cpu_attn_neon.hpp"
#endif

"""

    parts = [
        preamble,
        generate_helper_function(),
        "\n// Dispatch macro using encoded parameters\n",
        # x86_64 with AMX
        "#if defined(CPU_CAPABILITY_AMXBF16)\n",
        _render_dispatch_macro(["AMX", "VEC", "VEC16"]),
        # ARM64 with NEON
        "#elif defined(__aarch64__)\n",
        _render_dispatch_macro(["NEON", "VEC", "VEC16"]),
        # Fallback: VEC and VEC16 only
        "#else\n",
        _render_dispatch_macro(["VEC", "VEC16"]),
        "#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ */\n"
        "\n"
        "#endif // CPU_ATTN_DISPATCH_GENERATED_H\n",
    ]
    return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Write ``cpu_attn_dispatch_generated.h`` into this script's directory.

    The file is only rewritten when its content would change.  The build runs
    this script on every CMake configure; leaving an up-to-date header
    untouched keeps its modification time stable so sources including it are
    not recompiled needlessly.
    """
    output_path = os.path.join(
        os.path.dirname(__file__), "cpu_attn_dispatch_generated.h"
    )
    content = generate_header_file()

    try:
        with open(output_path) as existing:
            if existing.read() == content:
                return
    except (OSError, UnicodeDecodeError):
        # Missing or unreadable header: fall through and (re)create it.
        pass

    with open(output_path, "w") as f:
        f.write(content)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -26,6 +26,7 @@ NUM_HEADS = [
|
|||||||
(9, 3),
|
(9, 3),
|
||||||
]
|
]
|
||||||
HEAD_SIZES = [96, 128]
|
HEAD_SIZES = [96, 128]
|
||||||
|
HEAD_SIZES_VEC16 = [96, 80, 112, 128]
|
||||||
QTYPES = [torch.bfloat16, torch.half, torch.float32]
|
QTYPES = [torch.bfloat16, torch.half, torch.float32]
|
||||||
SLIDING_WINDOWS = [None, 256]
|
SLIDING_WINDOWS = [None, 256]
|
||||||
NUM_BLOCKS = [
|
NUM_BLOCKS = [
|
||||||
@@ -432,7 +433,7 @@ def test_varlen_with_paged_kv_normal_amx(
|
|||||||
|
|
||||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
@pytest.mark.parametrize("head_size", HEAD_SIZES_VEC16)
|
||||||
@pytest.mark.parametrize("block_size", [48])
|
@pytest.mark.parametrize("block_size", [48])
|
||||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
|
|||||||
Reference in New Issue
Block a user