diff --git a/dsv4/kernels/router/dense_router_decode_kernel.py b/dsv4/kernels/router/dense_router_decode_kernel.py index 2046875a..28e6ee37 100644 --- a/dsv4/kernels/router/dense_router_decode_kernel.py +++ b/dsv4/kernels/router/dense_router_decode_kernel.py @@ -383,9 +383,9 @@ class DenseRouterDecodeKernel: _done = cutlass.Bool(True) if not _done: ts = hs[root]; ti = hi[root]; ta = ha[root] - hs[root] = hs[smallest]; hi[root] = hi[smallest]; ha[root] = ha[smallest] - hs[smallest] = ts; hi[smallest] = ti; ha[smallest] = ta - root = smallest + hs[root] = hs[smallest]; hi[root] = hi[smallest]; ha[root] = ha[smallest] + hs[smallest] = ts; hi[smallest] = ti; ha[smallest] = ta + root = smallest # Write heap to shared memory for merge tid = (warp_idx * 32 + tidx) @@ -407,8 +407,8 @@ class DenseRouterDecodeKernel: cs = storage.heap_scores.data_ptr()[t*6+i] ci = storage.heap_indices.data_ptr()[t*6+i] ca = storage.heap_acts.data_ptr()[t*6+i] - if ci < 0: continue - if cs > fs[0] or (cs == fs[0] and ci < fi[0]): + if ci >= 0: + if cs > fs[0] or (cs == fs[0] and ci < fi[0]): fs[0] = cs; fi[0] = ci; fa[0] = ca # Sift down r = 0 @@ -425,9 +425,9 @@ class DenseRouterDecodeKernel: _done2 = cutlass.Bool(True) else: ts=fs[r]; ti=fi[r]; ta=fa[r] - fs[r]=fs[sm]; fi[r]=fi[sm]; fa[r]=fa[sm] - fs[sm]=ts; fi[sm]=ti; fa[sm]=ta - r = sm + fs[r]=fs[sm]; fi[r]=fi[sm]; fa[r]=fa[sm] + fs[sm]=ts; fi[sm]=ti; fa[sm]=ta + r = sm # Sort descending (selection sort, k=6) sorted_s = [cutlass.Float32(-1e30)]*6