[Kernel] Pass a device pointer into the quantize kernel for the scales (#5159)
This commit is contained in:
Committed by GitHub
Parent: 0ab278ca31
Commit: cbb2f59cc8
@@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
#endif
|
||||
|
||||
-void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
-                              float scale);
+void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor const& scale);
|
||||
|
||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
|
||||
Reference in New Issue
Block a user