[Kernel] Use flashinfer for decoding (#4353)
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
This commit is contained in:
@@ -96,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache_flash",
|
||||
&reshape_and_cache_flash,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"convert_fp8",
|
||||
&convert_fp8,
|
||||
|
||||
Reference in New Issue
Block a user