diff --git a/tests/unit/test_d2_flat_divide_diag.py b/tests/unit/test_d2_flat_divide_diag.py index 61a3e55c..b2a611f3 100644 --- a/tests/unit/test_d2_flat_divide_diag.py +++ b/tests/unit/test_d2_flat_divide_diag.py @@ -10,7 +10,7 @@ Test with: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_flat_divide_d import torch, math, cutlass import cutlass.cute as cute import cutlass.utils as utils -import cutlass.utils.sm100 as sm100 +import cutlass.utils as utils from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass import Float32, BFloat16, Int32, const_expr from cutlass.utils import LayoutEnum @@ -48,11 +48,11 @@ class ShapeDiagKernel: b_major = LayoutEnum.from_tensor(mK).mma_major_mode() v_major = LayoutEnum.from_tensor(mV).mma_major_mode() - qk_mma = sm100.make_trivial_tiled_mma( + qk_mma = utils.sm100.make_trivial_tiled_mma( self.q_dtype, self.q_dtype, a_major, b_major, Float32, self.cta_group, (128, 128), tcgen05.OperandSource.SMEM, ) - pv_mma = sm100.make_trivial_tiled_mma( + pv_mma = utils.sm100.make_trivial_tiled_mma( self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, v_major, Float32, self.cta_group, (128, min(hd, 256)), tcgen05.OperandSource.TMEM, ) @@ -62,24 +62,24 @@ class ShapeDiagKernel: pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik)) - q_smem_s = sm100.make_smem_layout_a(qk_mma, qk_mma_tiler, self.q_dtype, 1) - k_smem_s = sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, self.q_dtype, 1) - v_smem_s = sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, self.q_dtype, 1) + q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, qk_mma_tiler, self.q_dtype, 1) + k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, self.q_dtype, 1) + v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, self.q_dtype, 1) cluster_layout_vmnk = cute.tiled_divide( cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,) ) tma_q, tma_mQ = cute.nvgpu.make_tiled_tma_atom_A( - sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), mQ, cute.select(q_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape, ) tma_k, tma_mK = cute.nvgpu.make_tiled_tma_atom_B( - sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), mK, cute.select(k_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape, ) tma_v, tma_mV = cute.nvgpu.make_tiled_tma_atom_B( - sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), mV, cute.select(v_smem_s, mode=[0, 1, 2]), pv_mma_tiler, pv_mma, cluster_layout_vmnk.shape, )