Add batched RoPE kernel (#3095)

This commit is contained in:
Terry
2024-03-13 13:45:26 -07:00
committed by GitHub
parent ae0ccb4017
commit 7e9bd08f60
6 changed files with 421 additions and 41 deletions

View File

@@ -56,6 +56,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Register the C++ function `batched_rotary_embedding` on `ops` under the
// Python-visible name "batched_rotary_embedding". Per its help string, this is
// the rotary-embedding op (GPT-NeoX or GPT-J style, applied to query and key)
// with support for multiple LoRAs.
// NOTE(review): `ops` and `batched_rotary_embedding` are declared outside this
// hunk -- presumably `ops` is a pybind11 submodule of `m`; confirm, and confirm
// that the declaration matches the kernel's definition.
ops.def(
"batched_rotary_embedding",
&batched_rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
// Quantization ops
#ifndef USE_ROCM
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");