[CI/Build] Per file CUDA Archs (improve wheel size and dev build times) (#8845)
This commit is contained in:
@@ -26,6 +26,7 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "common/base.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include "common/mem.h"
|
||||
@@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("marlin_gemm", &marlin_gemm);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user