diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py index 83b75e685..d8e7db934 100644 --- a/vllm/model_executor/layers/fla/ops/utils.py +++ b/vllm/model_executor/layers/fla/ops/utils.py @@ -154,9 +154,13 @@ is_nvidia_hopper = is_nvidia and ( ) use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" is_gather_supported = hasattr(triton.language, "gather") -is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and ( - hasattr(triton.language, "_experimental_make_tensor_descriptor") - or hasattr(triton.language, "make_tensor_descriptor") +is_tma_supported = ( + is_nvidia_hopper + and os.getenv("FLA_USE_TMA", "0") == "1" + and ( + hasattr(triton.language, "_experimental_make_tensor_descriptor") + or hasattr(triton.language, "make_tensor_descriptor") + ) )