[Perf][GDN] Align TMA usage with upstream FLA (#38981)

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Artem Perevedentsev
Date: 2026-04-04 19:38:02 +03:00
Committed by: GitHub
Parent: a88ce94bbb
Commit: 99e5539a67

@@ -154,10 +154,14 @@ is_nvidia_hopper = is_nvidia and (
 )
 use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
 is_gather_supported = hasattr(triton.language, "gather")
-is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
-    hasattr(triton.language, "_experimental_make_tensor_descriptor")
-    or hasattr(triton.language, "make_tensor_descriptor")
-)
+is_tma_supported = (
+    is_nvidia_hopper
+    and os.getenv("FLA_USE_TMA", "0") == "1"
+    and (
+        hasattr(triton.language, "_experimental_make_tensor_descriptor")
+        or hasattr(triton.language, "make_tensor_descriptor")
+    )
+)


 def get_all_max_shared_mem():
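
With this change TMA becomes opt-in: it is used only on Hopper-class GPUs (is_nvidia_hopper), when the user sets FLA_USE_TMA=1, and when the installed Triton exposes a tensor-descriptor API. Below is a minimal, self-contained sketch of the resulting gate, handy for checking locally whether TMA kernels would be selected; the is_nvidia and is_nvidia_hopper definitions here are simplified assumptions based on the hunk context, not the exact upstream code.

import os

import torch
import triton.language as tl

# Simplified device checks (assumption: upstream's is_nvidia_hopper check may also
# inspect the device name; compute capability >= 9 is the core Hopper test).
is_nvidia = torch.cuda.is_available() and torch.version.cuda is not None
is_nvidia_hopper = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9

# TMA is now opt-in: FLA_USE_TMA defaults to "0" (disabled), matching upstream FLA.
is_tma_supported = (
    is_nvidia_hopper
    and os.getenv("FLA_USE_TMA", "0") == "1"
    and (
        hasattr(tl, "_experimental_make_tensor_descriptor")
        or hasattr(tl, "make_tensor_descriptor")
    )
)

print(f"TMA kernels enabled: {is_tma_supported}")

To opt in at run time, export FLA_USE_TMA=1 before launching (for example, FLA_USE_TMA=1 python serve.py); leaving the variable unset keeps TMA off, which is the default behavior after this change.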