diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index c73ffcdc..95d9a4fc 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -212,7 +212,7 @@ class FmhaV3StageCMulti: cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() - for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): + for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): kvh = kvp.acquire_and_advance(pk) cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) @@ -228,7 +228,7 @@ class FmhaV3StageCMulti: kvc.reset(); pk = kvc.try_wait() acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) acc_pipe.producer_acquire(acc_st) - for kt in range(n_kv_tiles): + for kt in range(self.n_kv_tiles): kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) sh = s_prod.acquire_and_advance() qk_mma.set(tcgen05.Field.ACCUMULATE, False)