[BugFix] Fix import error on non-blackwell machines (#21020)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-07-16 01:27:29 -04:00
committed by GitHub
parent 85431bd9ad
commit d31a647124
3 changed files with 12 additions and 16 deletions

View File

@@ -521,15 +521,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor page_table, Tensor workspace, float "
"scale,"
" int num_kv_splits) -> ()");
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
// conditionally compiled so impl in source file
// SM100 CUTLASS MLA workspace
ops.def(
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
" int sm_count, int num_kv_splits) "
"-> int");
ops.impl("sm100_cutlass_mla_get_workspace_size",
&sm100_cutlass_mla_get_workspace_size);
// conditionally compiled so impl in source file
// Compute NVFP4 block quantized tensor.
ops.def(