Memcpy kernel for flash attention (#29)
* optimize * add benchmark * add assert * add test
This commit is contained in:
committed by
GitHub
parent
b9926f7f66
commit
e3cec88aa5
@@ -20,6 +20,13 @@ void reshape_and_cache(
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping);
|
||||
|
||||
void gather_cached_kv(
|
||||
torch::Tensor& key,
|
||||
torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"swap_blocks",
|
||||
@@ -33,4 +40,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
m.def(
|
||||
"gather_cached_kv",
|
||||
&gather_cached_kv,
|
||||
"Gather key and value from the cache into contiguous QKV tensors");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user