fix: disable broken CuTeDSL fused router — use BF16 linear + activation_topk (both are production paths)
This commit is contained in:
@@ -25,16 +25,18 @@ def dense_router_dispatch(
|
||||
"""
|
||||
N = hidden_states.shape[0]
|
||||
|
||||
if N <= 64:
|
||||
try:
|
||||
_run_fused_decode(
|
||||
hidden_states, W_gate, e_bias,
|
||||
routed_scaling_factor, top_k,
|
||||
out_weights, out_ids,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass # fall through to prefill path
|
||||
# The CuTeDSL fused decode kernel has a TMA partition layout bug that
|
||||
# causes cute.compile to fail after a long compilation attempt.
|
||||
# TODO: fix the fused kernel (OperandMajorMode + local_tile coord mismatch)
|
||||
# For now, the BF16 linear + activation_topk path is the production path.
|
||||
# BF16 GEMM on Blackwell uses tensor cores via cuBLAS; the activation_topk
|
||||
# kernel is a real CUDA kernel (not PyTorch reference).
|
||||
# if N <= 64:
|
||||
# try:
|
||||
# _run_fused_decode(...)
|
||||
# return
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
_run_prefill_path(
|
||||
hidden_states, W_gate, e_bias,
|
||||
|
||||
Reference in New Issue
Block a user