[CI/Build] Per file CUDA Archs (improve wheel size and dev build times) (#8845)
This commit is contained in:
@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 90) {
|
||||
// Hopper
|
||||
// Hopper
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
#else
|
||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
} else if (version_num == 89) {
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
|
||||
} else if (version_num >= 80) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
@@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
"currently bias dtype must match output dtype ", c.dtype());
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 90) {
|
||||
// Hopper
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
#else
|
||||
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
} else if (version_num == 89) {
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
} else if (version_num >= 80) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
} else {
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
Reference in New Issue
Block a user