fix: properly scope swap code inside else/guard blocks, replace continue with if guard

This commit is contained in:
2026-05-31 23:51:43 +00:00
parent d0d765e1f2
commit 1bc0da0f35

View File

@@ -383,9 +383,9 @@ class DenseRouterDecodeKernel:
_done = cutlass.Bool(True)
if not _done:
ts = hs[root]; ti = hi[root]; ta = ha[root]
hs[root] = hs[smallest]; hi[root] = hi[smallest]; ha[root] = ha[smallest]
hs[smallest] = ts; hi[smallest] = ti; ha[smallest] = ta
root = smallest
hs[root] = hs[smallest]; hi[root] = hi[smallest]; ha[root] = ha[smallest]
hs[smallest] = ts; hi[smallest] = ti; ha[smallest] = ta
root = smallest
# Write heap to shared memory for merge
tid = (warp_idx * 32 + tidx)
@@ -407,8 +407,8 @@ class DenseRouterDecodeKernel:
cs = storage.heap_scores.data_ptr()[t*6+i]
ci = storage.heap_indices.data_ptr()[t*6+i]
ca = storage.heap_acts.data_ptr()[t*6+i]
if ci < 0: continue
if cs > fs[0] or (cs == fs[0] and ci < fi[0]):
if ci >= 0:
if cs > fs[0] or (cs == fs[0] and ci < fi[0]):
fs[0] = cs; fi[0] = ci; fa[0] = ca
# Sift down
r = 0
@@ -425,9 +425,9 @@ class DenseRouterDecodeKernel:
_done2 = cutlass.Bool(True)
else:
ts=fs[r]; ti=fi[r]; ta=fa[r]
fs[r]=fs[sm]; fi[r]=fi[sm]; fa[r]=fa[sm]
fs[sm]=ts; fi[sm]=ti; fa[sm]=ta
r = sm
fs[r]=fs[sm]; fi[r]=fi[sm]; fa[r]=fa[sm]
fs[sm]=ts; fi[sm]=ti; fa[sm]=ta
r = sm
# Sort descending (selection sort, k=6)
sorted_s = [cutlass.Float32(-1e30)]*6