From 391755ada0ffefa9a6a52b6f14dcaf22d1a463e0 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Fri, 16 May 2025 14:39:58 +0800 Subject: [PATCH] Fix JIT tests --- deep_gemm/jit/compiler.py | 2 +- deep_gemm/jit/runtime.py | 10 ++++++---- deep_gemm/jit_kernels/gemm.py | 2 +- deep_gemm/jit_kernels/m_grouped_gemm.py | 4 ++-- deep_gemm/jit_kernels/wgrad_gemm.py | 2 +- tests/test_jit.py | 5 +++-- 6 files changed, 14 insertions(+), 11 deletions(-) diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 54e3ab2..6401377 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -166,7 +166,7 @@ class Compiler: os.replace(tmp_cubin_path, cubin_path) # Put cache and return - runtime = runtime_cache.get(path, runtime_cls, name, kwargs) + runtime = runtime_cache.get(path, runtime_cls, name, kwargs, force_enable_cache=True) assert runtime is not None return runtime diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 3646d26..7a63bf1 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -33,7 +33,7 @@ class Runtime: def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: raise NotImplemented - def __call__(self, kwargs: Dict[str, Any]) -> cbd.CUresult: + def __call__(self, **kwargs) -> cbd.CUresult: # Load CUBIN if self.kernel is None: start_time = time.time_ns() @@ -81,17 +81,19 @@ class RuntimeCache: self.cache[path] = runtime def get(self, path: str, runtime_cls: Type[Runtime], - name: str = '', kwargs: Dict[str, Any] = None) -> Optional[Runtime]: + name: str = '', kwargs: Dict[str, Any] = None, + force_enable_cache: bool = False) -> Optional[Runtime]: # In Python runtime if path in self.cache: return self.cache[path] # Already compiled - if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path): + use_cache = force_enable_cache or not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) + if use_cache and os.path.exists(path) and Runtime.is_path_valid(path): # Print heuristic for the first time if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))): simplified_kwargs = dict() - for key, value in kwargs.items(): + for key, value in kwargs.items() if kwargs is not None else dict().items(): value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value simplified_kwargs[key] = value diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 9cb01c3..574f821 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -239,4 +239,4 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Generate, build and run the kernel code = FP8GemmRuntime.generate(kwargs) runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(kwargs) + runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index b072060..ca2fc79 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -103,7 +103,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Generate, build and run the kernel code = FP8GemmRuntime.generate(kwargs) runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(kwargs) + runtime(**kwargs) def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -202,4 +202,4 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Generate, build and run the kernel code = FP8GemmRuntime.generate(kwargs) runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(kwargs) + runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py index d0655cc..00b8cd1 100644 --- a/deep_gemm/jit_kernels/wgrad_gemm.py +++ b/deep_gemm/jit_kernels/wgrad_gemm.py @@ -110,7 +110,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Generate, build and run the kernel code = FP8WGradGemmRuntime.generate(kwargs) runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs) - runtime(kwargs) + runtime(**kwargs) def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], diff --git a/tests/test_jit.py b/tests/test_jit.py index 413bd01..26b7b36 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -74,7 +74,8 @@ static void __instantiate_kernel() {{ if __name__ == '__main__': print('Generated code:') - code = VectorAddRuntime.generate(T='float') + kwargs = {'T': 'float'} + code = VectorAddRuntime.generate(kwargs) print(code) print() @@ -85,7 +86,7 @@ if __name__ == '__main__': # Build print('Building ...') - func = compiler_cls.build('test_func', code, VectorAddRuntime) + func = compiler_cls.build('test_func', code, VectorAddRuntime, kwargs) # Run and check a = torch.randn((1024, ), dtype=torch.float32, device='cuda')