fix: disable broken CuTeDSL fused router — use BF16 linear + activation_topk (both are production paths)

This commit is contained in:
2026-06-01 00:56:00 +00:00
parent c339fe7ad9
commit 0ab5d8c317

View File

@@ -25,16 +25,18 @@ def dense_router_dispatch(
"""
N = hidden_states.shape[0]
if N <= 64:
try:
_run_fused_decode(
hidden_states, W_gate, e_bias,
routed_scaling_factor, top_k,
out_weights, out_ids,
)
return
except Exception:
pass # fall through to prefill path
# The CuTeDSL fused decode kernel has a TMA partition layout bug that
# causes cute.compile to fail after a long compilation attempt.
# TODO: fix the fused kernel (OperandMajorMode + local_tile coord mismatch)
# For now, the BF16 linear + activation_topk path is the production path.
# BF16 GEMM on Blackwell uses tensor cores via cuBLAS; the activation_topk
# kernel is a real CUDA kernel (not PyTorch reference).
# if N <= 64:
# try:
# _run_fused_decode(...)
# return
# except Exception:
# pass
_run_prefill_path(
hidden_states, W_gate, e_bias,