- Added --ab-compare flag to run both fused and unfused paths for first 3 layers
- Compares x_normed, gsa values, FP4 data, and GEMM outputs (q_a, kv)
- Added --no-fused-rmsnorm to disable P4 and use unfused path
- This will help diagnose the correctness regression introduced by P4
CRITICAL BUG: The old kernel had __syncthreads() and a spinlock INSIDE
the strided loop over num_valid entries. When num_valid % n_threads != 0
(i.e. essentially always at production context lengths), threads that
exit the loop early deadlock on the barrier while others wait forever.
Fix: per-thread local top-k in registers (LOCAL_K=8), block-level merge
after the loop completes. No in-loop barriers, no spinlocks.
Architecture:
- Each thread maintains a private min-heap of LOCAL_K best scores
- After the strided loop (no __syncthreads inside), threads write their
local top-k to shared memory
- Thread 0 builds the final top-k from all n_threads*LOCAL_K candidates
- For top_k=1024, n_threads=128, LOCAL_K=8: 1024 candidates = exact merge
- SMEM budget: w_h + merge heap + per-thread staging = ~30KB (well under 232KB)
Also updated the copy in dsv4/kernels/cuda/ (the one actually loaded
by the Python bridge).
Future optimization (separate from this fix):
- The dot products are scalar FP32 per thread. At 1M context this is slow.
Production path should use FP4 tcgen05 MMA (Stage F).
- The block-level merge is single-threaded. Could use warp-reduce or
bitonic sort for top_k > 256.
- Use half_step_to_e2m1 for E2M1 FP4 quantization (not LUT search)
- Use __nv_fp8_e4m3 + memcpy for block scale (not reinterpret_cast)
- Pack nibbles as (nibbles[2*i+1] << 4) | nibbles[2*i] (same as prod)
- Output uint8 buffers, then .view() to FP4/FP8 dtypes
- Handle near-zero block scale same as quantize_nvfp4.cu
Root cause: float row_max[n] is a VLA — not allowed in CUDA device code.
Fix: use shared memory with MHC_MAX_N=16 fixed-size slots.
Also: REMOVED the Python fallback in sinkhorn_knopp().
If the CUDA kernel fails, the pipeline DIES. No soft landing.
This is the correct behavior — silent fallback to broken precision
is worse than a loud crash.
The residual growth |X|→500-700 at L60 was likely caused by the Python
fallback running a DIFFERENT numerical path (BF16 accumulation in torch
ops vs FP32 in the CUDA kernel). With the fixed kernel, Sinkhorn should
produce properly doubly-stochastic B_l, bounding the residual.
We tried NVFP4 (Blackwell native FP4→MMA). Three approaches.
cos=0.995 round-trip seems fine in isolation but 4.5 effective bits
compounds fatally across 61 layers of mHC. FP8_E4M3's 5.3 effective
bits gives cos=0.9997 — that 0.4% difference is the margin between
working and broken. Kernels exist, path is proven, precision isn't.
block_reduce_sum/max write to smem[0..n_warps-1] but we passed &s_amax
(single float). For 128 threads / 4 warps, this wrote 4 floats starting
at &s_amax, corrupting adjacent shared variables (s_inv_rms, s_vals).
Fix: use s_scratch[8] array (4 for sum, 4 for max) with proper sizing.
CRITICAL: quantize must use the FP8-round-tripped block scale, not the raw
pre-FP8 value. The dequant reads the FP8 bytes back, so the quantize must
match exactly. Same pattern as quantize_nvfp4.cu. This was the root cause
of cos=0.925 (should be ~0.995).
Previous version used __shfl_down_sync for group-level amax reduction,
but shuffles operate at warp level and crossed group boundaries.
Fix: each thread independently quantizes its assigned 16-element blocks
from shared memory. Simpler and correct.
_apply_rope now uses dsv4.ops.rope_cuda (1 CUDA kernel per call)
instead of PyTorch ops (5-6 kernels per call).
Total: 183 RoPE calls × (5-1) = 732 launches saved per token.
With fallback to PyTorch if CUDA kernel fails.
cute.arch.fmin/fmax take scalar Float32, not TensorSSA.
Replace with cute.where() and arithmetic for TensorSSA compatibility.
Also changed subtile loop to unroll=1 for cute.where() compatibility.