[CI/Build] CPU release supports both of AVX2 and AVX512 (#35466)
Signed-off-by: jiang1.li <jiang1.li@intel.com> Co-authored-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -13,28 +13,16 @@ endif()
|
||||
#
|
||||
# Define environment variables for special configurations
|
||||
#
|
||||
set(ENABLE_AVX2 $ENV{VLLM_CPU_AVX2})
|
||||
set(ENABLE_AVX512 $ENV{VLLM_CPU_AVX512})
|
||||
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
|
||||
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
|
||||
set(ENABLE_X86_ISA $ENV{VLLM_CPU_X86})
|
||||
set(ENABLE_ARM_BF16 $ENV{VLLM_CPU_ARM_BF16})
|
||||
|
||||
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||
|
||||
|
||||
set (ENABLE_NUMA TRUE)
|
||||
|
||||
#
|
||||
# Check the compile flags
|
||||
#
|
||||
|
||||
if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-mf16c"
|
||||
)
|
||||
endif()
|
||||
|
||||
if(MACOSX_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-DVLLM_CPU_EXTENSION")
|
||||
@@ -78,18 +66,6 @@ function(check_sysctl TARGET OUT)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
|
||||
function (is_avx512_disabled OUT)
|
||||
set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512})
|
||||
if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true")
|
||||
set(${OUT} ON PARENT_SCOPE)
|
||||
else()
|
||||
set(${OUT} OFF PARENT_SCOPE)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
is_avx512_disabled(AVX512_DISABLED)
|
||||
|
||||
if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
message(STATUS "Apple Silicon Detected")
|
||||
set(APPLE_SILICON_FOUND TRUE)
|
||||
@@ -97,8 +73,6 @@ if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
check_sysctl(hw.optional.neon ASIMD_FOUND)
|
||||
check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND)
|
||||
else()
|
||||
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
|
||||
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
|
||||
find_isa(${CPUINFO} "Power11" POWER11_FOUND)
|
||||
find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
|
||||
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
|
||||
@@ -108,77 +82,32 @@ else()
|
||||
find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support
|
||||
|
||||
# Support cross-compilation by allowing override via environment variables
|
||||
if (ENABLE_AVX2)
|
||||
set(AVX2_FOUND ON)
|
||||
message(STATUS "AVX2 support enabled via VLLM_CPU_AVX2 environment variable")
|
||||
endif()
|
||||
if (ENABLE_AVX512)
|
||||
set(AVX512_FOUND ON)
|
||||
message(STATUS "AVX512 support enabled via VLLM_CPU_AVX512 environment variable")
|
||||
endif()
|
||||
if (ENABLE_ARM_BF16)
|
||||
set(ARM_BF16_FOUND ON)
|
||||
message(STATUS "ARM BF16 support enabled via VLLM_CPU_ARM_BF16 environment variable")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64" OR ENABLE_X86_ISA)
|
||||
set(ENABLE_X86_ISA ON)
|
||||
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3))
|
||||
message(FATAL_ERROR "X86 backend requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mf16c")
|
||||
list(APPEND CXX_COMPILE_FLAGS_AVX512 ${CXX_COMPILE_FLAGS})
|
||||
list(APPEND CXX_COMPILE_FLAGS_AVX2 ${CXX_COMPILE_FLAGS})
|
||||
list(APPEND CXX_COMPILE_FLAGS_AVX512
|
||||
"-mavx512f"
|
||||
"-mavx512vl"
|
||||
"-mavx512bw"
|
||||
"-mavx512dq")
|
||||
|
||||
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
|
||||
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
|
||||
endif()
|
||||
|
||||
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
|
||||
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
|
||||
set(ENABLE_AVX512VNNI ON)
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
|
||||
endif()
|
||||
|
||||
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
|
||||
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
|
||||
set(ENABLE_AMXBF16 ON)
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AMXBF16)
|
||||
else()
|
||||
set(ENABLE_AMXBF16 OFF)
|
||||
message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AMXBF16 OFF)
|
||||
message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
|
||||
endif()
|
||||
|
||||
elseif (AVX2_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
|
||||
message(WARNING "vLLM CPU backend using AVX2 ISA")
|
||||
|
||||
"-mavx512dq"
|
||||
"-mavx512bf16"
|
||||
"-mavx512vnni"
|
||||
"-mamx-bf16"
|
||||
"-mamx-tile")
|
||||
list(APPEND CXX_COMPILE_FLAGS_AVX2
|
||||
"-mavx2")
|
||||
elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
||||
message(STATUS "PowerPC detected")
|
||||
if (POWER9_FOUND)
|
||||
@@ -219,12 +148,12 @@ elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64")
|
||||
list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc")
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
|
||||
message(FATAL_ERROR "vLLM CPU backend requires X86, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
|
||||
endif()
|
||||
|
||||
|
||||
# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms)
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
||||
# Build oneDNN for GEMM kernels
|
||||
if (ENABLE_X86_ISA OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
||||
# Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64
|
||||
# TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN
|
||||
set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "")
|
||||
@@ -329,13 +258,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
|
||||
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
|
||||
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
|
||||
set(ONEDNN_BUILD_GRAPH "OFF")
|
||||
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
|
||||
set(ONEDNN_ENABLE_JIT_PROFILING "ON")
|
||||
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
|
||||
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
|
||||
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
|
||||
set(ONEDNN_VERBOSE "OFF")
|
||||
set(ONEDNN_ENABLE_MAX_CPU_ISA "ON")
|
||||
set(ONEDNN_ENABLE_CPU_ISA_HINTS "ON")
|
||||
set(ONEDNN_VERBOSE "ON")
|
||||
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
||||
|
||||
# TODO: Refactor this
|
||||
if (ENABLE_X86_ISA)
|
||||
# Note: only enable oneDNN for AVX512
|
||||
list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512})
|
||||
else()
|
||||
list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS})
|
||||
endif()
|
||||
|
||||
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
||||
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
|
||||
FetchContent_MakeAvailable(oneDNN)
|
||||
@@ -348,14 +285,20 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
|
||||
PRIVATE ${oneDNN_SOURCE_DIR}/src
|
||||
)
|
||||
target_link_libraries(dnnl_ext dnnl torch)
|
||||
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
|
||||
target_compile_options(dnnl_ext PRIVATE ${DNNL_COMPILE_FLAGS} -fPIC)
|
||||
list(APPEND LIBS dnnl_ext)
|
||||
set(USE_ONEDNN ON)
|
||||
else()
|
||||
set(USE_ONEDNN OFF)
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||
# TODO: Refactor this
|
||||
if (ENABLE_X86_ISA)
|
||||
message(STATUS "CPU extension (AVX512) compile flags: ${CXX_COMPILE_FLAGS_AVX512}")
|
||||
message(STATUS "CPU extension (AVX2) compile flags: ${CXX_COMPILE_FLAGS_AVX2}")
|
||||
else()
|
||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||
endif()
|
||||
|
||||
if(ENABLE_NUMA)
|
||||
list(APPEND LIBS numa)
|
||||
@@ -390,25 +333,6 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cpu/cpu_attn.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp")
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/shm.cpp"
|
||||
"csrc/cpu/cpu_wna16.cpp"
|
||||
"csrc/cpu/cpu_fused_moe.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/shm.cpp"
|
||||
@@ -421,21 +345,83 @@ if(USE_ONEDNN)
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||
if (ENABLE_X86_ISA)
|
||||
set(VLLM_EXT_SRC_AVX512
|
||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
|
||||
"csrc/cpu/shm.cpp"
|
||||
"csrc/cpu/cpu_wna16.cpp"
|
||||
"csrc/cpu/cpu_fused_moe.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/cpu_attn.cpp"
|
||||
"csrc/cpu/dnnl_kernels.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp"
|
||||
# TODO: Remove these files
|
||||
"csrc/cpu/activation.cpp"
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
"csrc/cpu/mla_decode.cpp"
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
#
|
||||
set(VLLM_EXT_SRC_AVX2
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/cpu_attn.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp"
|
||||
# TODO: Remove these files
|
||||
"csrc/cpu/activation.cpp"
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
"csrc/cpu/mla_decode.cpp"
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
|
||||
|
||||
define_extension_target(
|
||||
_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
LIBRARIES ${LIBS}
|
||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
||||
USE_SABI 3
|
||||
WITH_SOABI
|
||||
)
|
||||
message(STATUS "CPU extension (AVX512) source files: ${VLLM_EXT_SRC_AVX512}")
|
||||
message(STATUS "CPU extension (AVX2) source files: ${VLLM_EXT_SRC_AVX2}")
|
||||
|
||||
define_extension_target(
|
||||
_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_EXT_SRC_AVX512}
|
||||
LIBRARIES ${LIBS}
|
||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512}
|
||||
USE_SABI 3
|
||||
WITH_SOABI
|
||||
)
|
||||
|
||||
# For SGL kernels
|
||||
target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AVX512")
|
||||
# For AMX kernels
|
||||
target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AMXBF16")
|
||||
|
||||
define_extension_target(
|
||||
_C_AVX2
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_EXT_SRC_AVX2}
|
||||
LIBRARIES ${LIBS}
|
||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX2}
|
||||
USE_SABI 3
|
||||
WITH_SOABI
|
||||
)
|
||||
else()
|
||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||
#
|
||||
# Define extension targets
|
||||
#
|
||||
define_extension_target(
|
||||
_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
LIBRARIES ${LIBS}
|
||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
||||
USE_SABI 3
|
||||
WITH_SOABI
|
||||
)
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C extension.")
|
||||
|
||||
@@ -4,6 +4,10 @@
|
||||
|
||||
#include <torch/library.h>
|
||||
|
||||
// Note: overwrite the external definition to share the same name between
|
||||
// libraries that use different ISAs.
|
||||
#define TORCH_EXTENSION_NAME _C
|
||||
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids);
|
||||
|
||||
void release_dnnl_matmul_handler(int64_t handler);
|
||||
@@ -324,19 +328,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"str act, str isa) -> ()");
|
||||
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
// CPU utils
|
||||
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
|
||||
cpu_ops.def(
|
||||
ops.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
|
||||
ops.def(
|
||||
"mla_decode_kvcache("
|
||||
" Tensor! out, Tensor query, Tensor kv_cache,"
|
||||
" float scale, Tensor block_tables, Tensor seq_lens) -> ()");
|
||||
cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
|
||||
ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
11
setup.py
11
setup.py
@@ -818,7 +818,7 @@ def _is_xpu() -> bool:
|
||||
|
||||
|
||||
def _build_custom_ops() -> bool:
|
||||
return _is_cuda() or _is_hip() or _is_cpu()
|
||||
return _is_cuda() or _is_hip()
|
||||
|
||||
|
||||
def get_rocm_version():
|
||||
@@ -987,6 +987,15 @@ if _is_cuda():
|
||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
|
||||
)
|
||||
|
||||
if _is_cpu():
|
||||
import platform
|
||||
|
||||
if platform.machine() in ("x86_64", "AMD64"):
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
ext_modules.append(CMakeExtension(name="vllm._C_AVX2"))
|
||||
else:
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
||||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
||||
|
||||
@@ -178,9 +178,7 @@ def mla_decode_kvcache_cpu(
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._C_cpu.mla_decode_kvcache(
|
||||
out, query, kv_cache, scale, block_tables, seq_lens
|
||||
)
|
||||
torch.ops._C.mla_decode_kvcache(out, query, kv_cache, scale, block_tables, seq_lens)
|
||||
|
||||
|
||||
# merge attn states ops
|
||||
|
||||
@@ -483,3 +483,27 @@ class CpuPlatform(Platform):
|
||||
@classmethod
|
||||
def support_hybrid_kv_cache(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def import_kernels(cls) -> None:
|
||||
if Platform.get_cpu_architecture() in (CpuArchEnum.X86,):
|
||||
if torch._C._cpu._is_avx512_supported():
|
||||
try:
|
||||
import vllm._C # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C: %r", e)
|
||||
else:
|
||||
# Note: The lib name is _C_AVX2, but the module name is _C.
|
||||
# This will cause an exception "dynamic module does not define
|
||||
# module export function". But the library is imported
|
||||
# successfully. So ignore the exception for now, until we find
|
||||
# a solution.
|
||||
try:
|
||||
import vllm._C_AVX2 # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C_AVX2: %r", e)
|
||||
else:
|
||||
try:
|
||||
import vllm._C # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C: %r", e)
|
||||
|
||||
@@ -85,7 +85,7 @@ class CPUWorker(Worker):
|
||||
self.local_omp_cpuid = omp_cpuids_list[self.rank]
|
||||
|
||||
if self.local_omp_cpuid != "nobind":
|
||||
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
|
||||
ret = torch.ops._C.init_cpu_threads_env(self.local_omp_cpuid)
|
||||
if ret:
|
||||
logger.info(ret)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user