[Kernel] Pass a device pointer into the quantize kernel for the scales (#5159)

This commit is contained in:
Tyler Michael Smith
2024-06-03 12:52:30 -04:00
committed by GitHub
parent 0ab278ca31
commit cbb2f59cc8
5 changed files with 16 additions and 11 deletions

View File

@@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
float scale);
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table);