fix: use utils.sm100 instead of sm100 in diagnostic

This commit is contained in:
2026-05-25 02:34:25 +00:00
parent 7599801f57
commit 684e9a85fe

View File

@@ -10,7 +10,7 @@ Test with: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_flat_divide_d
import torch, math, cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.utils.sm100 as sm100
import cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
@@ -48,11 +48,11 @@ class ShapeDiagKernel:
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
qk_mma = sm100.make_trivial_tiled_mma(
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype, a_major, b_major,
Float32, self.cta_group, (128, 128), tcgen05.OperandSource.SMEM,
)
pv_mma = sm100.make_trivial_tiled_mma(
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, v_major,
Float32, self.cta_group, (128, min(hd, 256)), tcgen05.OperandSource.TMEM,
)
@@ -62,24 +62,24 @@ class ShapeDiagKernel:
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
q_smem_s = sm100.make_smem_layout_a(qk_mma, qk_mma_tiler, self.q_dtype, 1)
k_smem_s = sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, self.q_dtype, 1)
v_smem_s = sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, self.q_dtype, 1)
q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, qk_mma_tiler, self.q_dtype, 1)
k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, self.q_dtype, 1)
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, self.q_dtype, 1)
cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,)
)
tma_q, tma_mQ = cute.nvgpu.make_tiled_tma_atom_A(
sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
mQ, cute.select(q_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape,
)
tma_k, tma_mK = cute.nvgpu.make_tiled_tma_atom_B(
sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
mK, cute.select(k_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape,
)
tma_v, tma_mV = cute.nvgpu.make_tiled_tma_atom_B(
sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
mV, cute.select(v_smem_s, mode=[0, 1, 2]), pv_mma_tiler, pv_mma, cluster_layout_vmnk.shape,
)