From 4373af2e82ab083f1f6f50bd8add978fee7f23e6 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Thu, 15 May 2025 16:36:40 +0800 Subject: [PATCH] Add `DG_PRINT_CONFIGS` --- README.md | 2 +- deep_gemm/jit/compiler.py | 14 +++++++------- deep_gemm/jit/runtime.py | 16 ++++++++++++++-- deep_gemm/jit_kernels/gemm.py | 2 +- deep_gemm/jit_kernels/m_grouped_gemm.py | 4 ++-- deep_gemm/jit_kernels/wgrad_gemm.py | 2 +- 6 files changed, 26 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 5f5388b..db89b2d 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ The library also provides some environment variables, which may be useful: - Post optimization - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default - Heuristic selection - - `DG_PRINT_HEURISTIC`: `0` or `1`, print selected configs for each shape, `0` by default + - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default - Testing - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 2ab6b25..54e3ab2 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -5,7 +5,7 @@ import re import subprocess import time import uuid -from typing import List, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import cuda.bindings import cuda.bindings.nvrtc as nvrtc @@ -128,7 +128,7 @@ class Compiler: return [get_jit_include_dir()] @classmethod - def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: + def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime: # Compiler flags flags = cls.flags() @@ -140,7 +140,7 @@ class Compiler: # Check runtime cache or file system hit global runtime_cache - cached_runtime = runtime_cache.get(path, runtime_cls) + cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs) if cached_runtime is not None: if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Using cached JIT runtime {name} during build') @@ -166,8 +166,8 @@ class Compiler: os.replace(tmp_cubin_path, cubin_path) # Put cache and return - runtime = runtime_cls(path) - runtime_cache[path] = runtime + runtime = runtime_cache.get(path, runtime_cls, name, kwargs) + assert runtime is not None return runtime @@ -279,6 +279,6 @@ class NVRTCCompiler(Compiler): assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}' -def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime: +def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime: compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler - return compiler_cls.build(name, code, runtime_cls=runtime_cls) + return compiler_cls.build(name, code, runtime_cls, kwargs) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 041e23f..ffcd0b3 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -1,9 +1,11 @@ +import copy import os import subprocess import time +import torch import cuda.bindings.driver as cbd -from typing import List, Optional, Type +from typing import Any, Dict, Optional, Type from torch.utils.cpp_extension import CUDA_HOME @@ -79,13 +81,23 @@ class RuntimeCache: def __setitem__(self, path: str, runtime: Runtime) -> None: self.cache[path] = runtime - def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]: + def get(self, path: str, runtime_cls: Type[Runtime], + name: str = '', kwargs: Dict[str, Any] = None) -> Optional[Runtime]: # In Python runtime if path in self.cache: return self.cache[path] # Already compiled if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path): + # Print heuristic for the first time + if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))): + simplified_kwargs = dict() + for key, value in kwargs.items(): + value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value + value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value + simplified_kwargs[key] = value + print(f'Put kernel {name} with {simplified_kwargs} into runtime cache') + runtime = runtime_cls(path) self.cache[path] = runtime return runtime diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 3adef72..343e84a 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -238,5 +238,5 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Generate, build and run the kernel code = FP8GemmRuntime.generate(**kwargs) - runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime) + runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index ef1a088..73fd2f1 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -102,7 +102,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Generate, build and run the kernel code = FP8GemmRuntime.generate(**kwargs) - runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime) + runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) runtime(**kwargs) @@ -201,5 +201,5 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Generate, build and run the kernel code = FP8GemmRuntime.generate(**kwargs) - runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime) + runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py index dea91e2..8a38578 100644 --- a/deep_gemm/jit_kernels/wgrad_gemm.py +++ b/deep_gemm/jit_kernels/wgrad_gemm.py @@ -111,7 +111,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Generate, build and run the kernel code = FP8WGradGemmRuntime.generate(**kwargs) - runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime) + runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs) runtime(**kwargs)