From 99e5539a6701e77b0586e7d5a1fdddf19a1d42f2 Mon Sep 17 00:00:00 2001
From: Artem Perevedentsev
Date: Sat, 4 Apr 2026 19:38:02 +0300
Subject: [PATCH] [Perf][GDN] Align TMA usage with upstream FLA (#38981)

Signed-off-by: Artem Perevedentsev
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 vllm/model_executor/layers/fla/ops/utils.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py
index 83b75e685..d8e7db934 100644
--- a/vllm/model_executor/layers/fla/ops/utils.py
+++ b/vllm/model_executor/layers/fla/ops/utils.py
@@ -154,9 +154,13 @@ is_nvidia_hopper = is_nvidia and (
 )
 use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
 is_gather_supported = hasattr(triton.language, "gather")
-is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
-    hasattr(triton.language, "_experimental_make_tensor_descriptor")
-    or hasattr(triton.language, "make_tensor_descriptor")
+is_tma_supported = (
+    is_nvidia_hopper
+    and os.getenv("FLA_USE_TMA", "0") == "1"
+    and (
+        hasattr(triton.language, "_experimental_make_tensor_descriptor")
+        or hasattr(triton.language, "make_tensor_descriptor")
+    )
 )
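
The patch makes the TMA path opt-in: it now requires a Hopper GPU, an
explicit FLA_USE_TMA=1 in the environment, and a Triton build that
exposes a tensor-descriptor API. Below is a minimal standalone sketch
of the resulting gate for trying the change locally. The is_nvidia and
is_nvidia_hopper definitions here are approximations of the flags
defined just above the hunk in utils.py (their real definitions are
not fully shown in this diff); only the is_tma_supported expression is
taken verbatim from the patch.

    import os

    import torch
    import triton
    import triton.language

    # Approximation: utils.py derives these from the active backend and
    # device; compute capability >= 9 (Hopper) stands in here. The
    # short-circuit `and` keeps get_device_capability from running on
    # machines without a CUDA device.
    is_nvidia = torch.cuda.is_available() and torch.version.cuda is not None
    is_nvidia_hopper = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9

    # Verbatim from the patch: TMA is used only on Hopper, only when
    # FLA_USE_TMA=1 is set, and only if Triton exposes a
    # tensor-descriptor API (experimental or stable name).
    is_tma_supported = (
        is_nvidia_hopper
        and os.getenv("FLA_USE_TMA", "0") == "1"
        and (
            hasattr(triton.language, "_experimental_make_tensor_descriptor")
            or hasattr(triton.language, "make_tensor_descriptor")
        )
    )

    print(f"TMA path enabled: {is_tma_supported}")

Running the sketch with FLA_USE_TMA=1 exercises the enabled branch on a
Hopper machine; on any other configuration the flag stays False, which
matches upstream FLA's default of leaving TMA off.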