fix: use utils.sm100 instead of sm100 in diagnostic
This commit is contained in:
@@ -10,7 +10,7 @@ Test with: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_flat_divide_d
|
||||
import torch, math, cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
-import cutlass.utils.sm100 as sm100
|
||||
+import cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
@@ -48,11 +48,11 @@ class ShapeDiagKernel:
|
||||
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
|
||||
|
||||
-qk_mma = sm100.make_trivial_tiled_mma(
|
||||
+qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype, a_major, b_major,
|
||||
Float32, self.cta_group, (128, 128), tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
-pv_mma = sm100.make_trivial_tiled_mma(
|
||||
+pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, v_major,
|
||||
Float32, self.cta_group, (128, min(hd, 256)), tcgen05.OperandSource.TMEM,
|
||||
)
|
||||
@@ -62,24 +62,24 @@ class ShapeDiagKernel:
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
|
||||
-q_smem_s = sm100.make_smem_layout_a(qk_mma, qk_mma_tiler, self.q_dtype, 1)
|
||||
-k_smem_s = sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, self.q_dtype, 1)
|
||||
-v_smem_s = sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, self.q_dtype, 1)
|
||||
+q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, qk_mma_tiler, self.q_dtype, 1)
|
||||
+k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, self.q_dtype, 1)
|
||||
+v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, self.q_dtype, 1)
|
||||
|
||||
cluster_layout_vmnk = cute.tiled_divide(
|
||||
cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,)
|
||||
)
|
||||
|
||||
tma_q, tma_mQ = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
-sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
+utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
mQ, cute.select(q_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape,
|
||||
)
|
||||
tma_k, tma_mK = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
-sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
+utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
mK, cute.select(k_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape,
|
||||
)
|
||||
tma_v, tma_mV = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
-sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
+utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
mV, cute.select(v_smem_s, mode=[0, 1, 2]), pv_mma_tiler, pv_mma, cluster_layout_vmnk.shape,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user