diff --git a/dsv4/kernels/cuda/gather_kv.cu b/dsv4/kernels/cuda/gather_kv.cu index 77692d91..7ea27878 100644 --- a/dsv4/kernels/cuda/gather_kv.cu +++ b/dsv4/kernels/cuda/gather_kv.cu @@ -103,4 +103,8 @@ void gather_kv_cuda( PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("gather_kv", &gather_kv_cuda, "Gather KV entries into dense tile"); + // gather_swa + gather_all_compressed are defined in gather_swa.cu + // (compiled together into the same .so) + m.def("gather_swa", &gather_swa_cuda, "Gather SWA window into dense BF16 tile"); + m.def("gather_all_compressed", &gather_all_compressed_cuda, "Gather all compressed KV for HCA"); } diff --git a/dsv4/kernels/cuda/gather_swa.cu b/dsv4/kernels/cuda/gather_swa.cu index d34a18a4..9123a0f3 100644 --- a/dsv4/kernels/cuda/gather_swa.cu +++ b/dsv4/kernels/cuda/gather_swa.cu @@ -171,7 +171,5 @@ void gather_all_compressed_cuda( } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("gather_swa", &gather_swa_cuda, "Gather SWA window into dense BF16 tile"); - m.def("gather_all_compressed", &gather_all_compressed_cuda, "Gather all compressed KV for HCA"); -} +// gather_swa functions are registered in gather_kv.cu's PYBIND11_MODULE +// (both files are compiled together into a single cache_gather.so)