[BugFix] Fix import error on non-blackwell machines (#21020)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-07-16 01:27:29 -04:00
committed by GitHub
parent 85431bd9ad
commit d31a647124
3 changed files with 12 additions and 16 deletions

View File

@@ -18,6 +18,7 @@ limitations under the License.
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
#include "core/registration.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
@@ -270,4 +271,13 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
}
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
}
// clang-format on