Memcpy kernel for flash attention (#29)

* optimize

* add benchmark

* add assert

* add test
This commit is contained in:
Siyuan (Ryans) Zhuang
2023-04-10 18:22:49 -07:00
committed by GitHub
parent b9926f7f66
commit e3cec88aa5
4 changed files with 293 additions and 0 deletions

View File

@@ -20,6 +20,13 @@ void reshape_and_cache(
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);
void gather_cached_kv(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"swap_blocks",
@@ -33,4 +40,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
m.def(
"gather_cached_kv",
&gather_cached_kv,
"Gather key and value from the cache into contiguous QKV tensors");
}