[CI/Build] Per file CUDA Archs (improve wheel size and dev build times) (#8845)
This commit is contained in:
@@ -28,6 +28,7 @@
|
||||
|
||||
#include "common/base.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
#include "core/registration.h"
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
@@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
// Binds the C++ function gptq_marlin_24_gemm as the implementation of the
// "gptq_marlin_24_gemm" op, in the library named by TORCH_EXTENSION_NAME,
// for the CUDA dispatch key.
// NOTE(review): TORCH_LIBRARY_IMPL_EXPAND presumably comes from
// core/registration.h and wraps TORCH_LIBRARY_IMPL so the library-name macro
// is expanded first — confirm against that header.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user