[BugFix] Fix import error on non-blackwell machines (#21020)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -521,15 +521,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor page_table, Tensor workspace, float "
|
||||
"scale,"
|
||||
" int num_kv_splits) -> ()");
|
||||
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// SM100 CUTLASS MLA workspace
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
|
||||
" int sm_count, int num_kv_splits) "
|
||||
"-> int");
|
||||
ops.impl("sm100_cutlass_mla_get_workspace_size",
|
||||
&sm100_cutlass_mla_get_workspace_size);
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
|
||||
Reference in New Issue
Block a user