D1: Add PipelineState for k_sub TMA path
This commit is contained in:
@@ -214,6 +214,8 @@ class FmhaKernel:
|
||||
if warp_idx == self.tma_warp_id:
|
||||
if const_expr(self.n_k_sub_tiles > 1):
|
||||
# K sub-tiling path (hd=512): load Q and K per k_sub using pipeline barriers
|
||||
qp_stage = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.q_stage)
|
||||
kvp_stage = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.kv_stage)
|
||||
for kt in range(self.n_kv_tiles):
|
||||
for k_sub in range(self.n_k_sub_tiles):
|
||||
# Load Q[k_sub] → sQ
|
||||
@@ -224,7 +226,7 @@ class FmhaKernel:
|
||||
kvp.producer_acquire(kvp_stage)
|
||||
cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvp_stage.index)], tma_bar_ptr=kvp_stage.barrier)
|
||||
kvp.producer_commit(kvp_stage); kvp_stage.advance()
|
||||
# Load V[kt] → sV
|
||||
# Load V[kt] → sV (uses same K/V pipeline)
|
||||
kvp.producer_acquire(kvp_stage)
|
||||
cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, kvp_stage.index)], tma_bar_ptr=kvp_stage.barrier)
|
||||
kvp.producer_commit(kvp_stage); kvp_stage.advance()
|
||||
|
||||
Reference in New Issue
Block a user