diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 6a81f159f..d8369108d 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -1,7 +1,9 @@ #include #include -#include "../per_token_group_quant_8bit.h" +#ifndef USE_ROCM + #include "../per_token_group_quant_8bit.h" +#endif #include @@ -339,10 +341,12 @@ void dynamic_scaled_int8_quant( }); } +#ifndef USE_ROCM void per_token_group_quant_int8(const torch::Tensor& input, torch::Tensor& output_q, torch::Tensor& output_s, int64_t group_size, double eps, double int8_min, double int8_max) { per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max); -} \ No newline at end of file +} +#endif