Fix diagnostic test: same Int32(kt) + n_kv_tiles fixes
This commit is contained in:
@@ -21,6 +21,7 @@ class FmhaV3Diag:
|
||||
If this fails, the problem is in TMA/MMA, not softmax."""
|
||||
def __init__(self, s_k=256):
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
@@ -127,7 +128,7 @@ class FmhaV3Diag:
|
||||
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
|
||||
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
|
||||
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
|
||||
n_kv_tiles = cute.size(gK, mode=[3])
|
||||
n_kv_tiles = self.n_kv_tiles
|
||||
|
||||
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
|
||||
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
|
||||
@@ -137,7 +138,7 @@ class FmhaV3Diag:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,0,None,0)]
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
@@ -160,16 +161,17 @@ class FmhaV3Diag:
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# TMA LOAD (combined K+V barrier)
|
||||
# TMA LOAD (combined K+V barrier, Int32(kt) for GMEM coord)
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset(); pk = kvp.try_acquire()
|
||||
for kt in cutlass.range(n_kv_tiles, unroll=1):
|
||||
for kt in cutlass.range(self.n_kv_tiles, unroll=1):
|
||||
coord = Int32(kt)
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_k, tBgK[(None, coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
@@ -180,7 +182,7 @@ class FmhaV3Diag:
|
||||
kvc.reset(); pk = kvc.try_wait()
|
||||
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_st)
|
||||
for kt in range(n_kv_tiles):
|
||||
for kt in range(self.n_kv_tiles):
|
||||
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
|
||||
sh = s_prod.acquire_and_advance()
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
@@ -227,7 +229,7 @@ class FmhaV3Diag:
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
for kt in range(n_kv_tiles):
|
||||
for kt in range(self.n_kv_tiles):
|
||||
si_handle = s_cons.wait_and_advance()
|
||||
|
||||
# Load S from TMEM
|
||||
|
||||
Reference in New Issue
Block a user