- F.linear(x, W) computes x @ W.T which caused shape mismatch when
W_gate was pre-transposed to [E, H]
- Use torch.matmul(x, W_gate) instead — computes x @ W directly, no
transpose needed, no FP32 conversion, fully graph-capturable
- W_gate stays as [H, E] (original checkpoint shape)
- Run GEMM in BF16 (not FP32) during graph capture — Blackwell tensor cores
handle BF16 natively; FP32 GEMM triggers cudaErrorStreamCaptureUnsupported
- Pre-transpose W_gate to [E, H] at load time — avoids .T view during capture
- Convert only logits output to FP32 for sqrt(softplus) numerical stability
- This fixes the graph capture failure at layer 0 Graph B
hidden_states.float() and gate_bf16.T.float() create new FP32 tensors
during CUDA graph capture, which is not graph-capturable.
Fix: run the linear in BF16 (Blackwell tensor cores handle BF16 natively),
then convert only the output logits to FP32 for numerical stability
in sqrt(softplus). The single logits.float() is graph-capturable
because it's a unary op with a pre-existing output buffer.
Python view operations (reshape, transpose, permute) are not
graph-capturable — they cause cudaErrorStreamCaptureUnsupported.
Added:
- dsv4/kernels/cuda/blackwell_swizzle.cu: custom CUDA kernel for 32_4_4 swizzle
- to_blocked(): detects graph capture, uses CUDA kernel instead of Python views
- MoE _assemble_scales_cudagraph_safe: same treatment
- Shared expert _assemble_scales_single_group: same treatment
- Linear _assemble_scales_single_group: same treatment
- Pre-allocated swizzled output buffers for all layers (avoids torch.empty_like)
The CUDA kernel writes to a pre-allocated buffer — no per-step allocations.
Eager path unchanged (still uses fast Python view operations).
The positional bias (ape/B) should only modulate the compression
softmax logits (Z + B), NOT be added to the KV content itself.
Paper equation: compressed = softmax(Z + B) · C
Bug was doing: compressed = softmax(Z + B) · (C + B) — poisons every
compressed KV entry with learned positional-bias content.
Fixed in both CSA (compress_csa_reduce_kernel) and HCA
(hca_compress_reduce_kernel) paths in compressor_reduce.cu.
The kernel was using head strides for the T (query row) dimension,
which happened to work for T=1 (qr=0 always) but was wrong for T>1.
For (B,H,T,NOPE) layout:
- Head stride = T*NOPE, but T stride = NOPE
- Scale head stride = T, but T stride = 1
- RoPE head stride = T*ROPE, but T stride = ROPE
Added q_nope_t_stride, q_scale_t_stride, q_rope_t_stride to params
struct, C API, and Python wrapper.
Replace complex n_sub-iterating read with the same HD/8 iteration
pattern as the proven decode kernel. Extract from lane qr%32 instead
of always lane 0. For qr>=32, use warp 1; for qr>=64, add TMEM offset.
This should fix the row 1 accuracy issue (was cos=0.94 vs decode).
prefill_read_qk_rows was reading from address 0 (sg_off + n * 8)
instead of tb + sg_off + n * 8. This caused garbage QK values,
explaining the 0.928 cosine for T=1 and NaN for T>1.
Two critical fixes:
1. prefill_read_pv_all_subs: was missing 'tb' base in TMEM read address
2. PV MMA ACCUMULATE: use pv_kt == 0 (not kv_tile==0 && pv_kt==0 && n_sub==0)
so each query row's PV starts fresh instead of accumulating into previous row's result
Critical bug: prefill_read_pv_row only read n_sub=0 (16 out of 512 HD dims).
Replaced with prefill_read_pv_all_subs that iterates over all 32 n_sub groups.
Also fixed TMEM row-group/warp mapping for rows 32-127.
Key insight: tcgen05.ld.32x32b.x8 maps warp 0 to rows 0-31 and warp 1 to
rows 32-63 from the SAME TMEM address. The hardware routes row slices
based on warp position in the warpgroup.
Fix approach (from external LLM review):
- Warps 0-1 both read from tb + col_base (same address)
- Each warp writes partial scores to its own sWarpScores partition
- After __syncthreads(), merge both partitions for final 64-head scores
- No race conditions, no cross-warp accumulation bugs
Revert to 16x256b.x1 approach (reads 64 rows from single column).
Previous hang was likely due to TMEM_COLS=128 (too small).
With TMEM_COLS=512, the full 128-row MMA output fits in TMEM.
Lane i reads rows 4i..4i+3. Lanes 0-15 cover rows 0-63.
4 warps (0-3) each process 32 columns, computing weighted ReLU scores.
ROOT CAUSE: The TMEM read for rows 32-63 was wrong. The 32x32b.x8
instruction reads 32 rows per warp. Per P7 docs, warp 0 sees rows 0-31
and warp 1 sees rows 32-63 from the SAME TMEM address. There is no TMEM
offset for different row groups — the row-to-lane mapping depends on
the warp ID.
Fix: warp 0 reads heads 0-31, warp 1 reads heads 32-63 from tb + col_base.
Cross-warp reduce via SMEM to compute full 64-head weighted ReLU scores.
The MMA produces 128 rows × 128 cols = 4 row-groups × 128 TMEM cols = 512 total.
Even though we only read rows 0-63, the MMA writes all 128 rows.
TMEM_COLS must match the MMA output size, not just the read size.
ROOT CAUSES:
1. tcgen05.ld.16x256b.x1 was hanging — either invalid instruction or unaligned
2. TMEM_COLS=128 was too small for 64-row MMA output (needs 256 for 2 row-groups)
3. TMEM row-group addressing: rows 32-63 are at offset SK_TILE (128) in TMEM
Fixes:
- Use tcgen05.ld.32x32b.x8 (proven in B1 FMHA) instead of 16x256b.x1
- Increase TMEM_COLS from 128 to 256
- Read both row-groups (0-31 and 32-63) per 8-column chunk
- Each lane handles head i (from row-group 0) and head 32+i (from row-group 1)
- Warp-level reduce sums contributions from all 64 heads per column
ROOT CAUSE: canon_idx_bf16_16x16(kk, dd) was swapping the outer/inner group
structure compared to the working TMA-loaded V layout in the multitile kernel.
Working layout: (lr/8)*128 + (dd/8)*64 + (dd%8)*8 + (lr%8)
B1 with (kk,dd): (dd/8)*128 + (kk/8)*64 + (kk%8)*8 + (dd%8) <- WRONG
B1 with (dd,kk): (kk/8)*128 + (dd/8)*64 + (dd%8)*8 + (kk%8) <- CORRECT
This caused the V matrix to be loaded into SMEM with transposed group
structure, producing garbage output (cos=0.158 vs BF16 reference).
- New kernel: dsv4/kernels/cuda/indexer_fp8_score_topk.cu
- Native Blackwell FP8 GEMM via tcgen05.mma.kind::f8f6f4
- Q (n_ih=64, ihd=128) quantized BF16→FP8, K consumed directly as FP8_E4M3
- TMEM read using 16x256b.x1 (4-warps parallel, proven from B1 FMHA)
- On-the-fly: dequant (q_scale*k_scale) → ReLU → weighted sum → top-k
- No global BF16 staging of indexer keys, no FP32 einsum on CUDA cores
- Per-thread register heap top-k (same algorithm as indexer_score_topk.cu)
- Modified: single_shot_inference.py
- Indexer.forward() now takes kv_cache directly (not comp_idx_kv BF16)
- Consumes FP8 indexer keys from cache without BF16 dequantization
- Dispatches to B2 FP8 kernel for T=1, n_ih=64, ihd=128 (production decode)
- FP32 einsum fallback retained only for T>1 (prefill)
- Removed 'Intentional first-pass limits' section from B1 doc
(those limits ARE the correct production design, not shortcuts)
torch.utils.cpp_extension.load creates a 'lock' file in the build
directory during compilation. If the compiling process is killed
(OOM, timeout, user interrupt), the lock file is never removed and
subsequent processes spin forever polling it (clock_nanosleep(100ms)
→ stat(lock) → repeat).
Fix: _cleanup_stale_lock() removes lock files older than 10 minutes
before any compilation attempt. This is the correct threshold — CUDA
kernel compilation should never take more than a few minutes, so a
10-minute-old lock is guaranteed stale.
1. score_topk.py: Fix docstring — K^IComp[s] is shared (MQA), not per-head K^IComp[s,h]
Matches the .cu kernel and production Indexer.forward() einsum.
2. score_topk.py: Add WARNING about valid_lens broadcast being wrong for batched prefill
3. csa_indexer.py: Replace random weights with RuntimeError — CSAIndexer has no
checkpoint loading. Production uses the Indexer class in single_shot_inference.py.
4. csa_indexer.py: Document RoPE assumption — indexer queries/keys have no RoPE.
NEEDS VERIFICATION against HF reference.
The CUDA loader (dsv4/kernels/cuda/loader.py) resolves all .cu
files relative to dsv4/kernels/cuda/. The indexer/ subfolder copies
were never loaded — they were dead code that could silently diverge
from the canonical copies in cuda/.
CRITICAL BUG: The old kernel had __syncthreads() and a spinlock INSIDE
the strided loop over num_valid entries. When num_valid % n_threads != 0
(i.e. essentially always at production context lengths), threads that
exit the loop early deadlock on the barrier while others wait forever.
Fix: per-thread local top-k in registers (LOCAL_K=8), block-level merge
after the loop completes. No in-loop barriers, no spinlocks.
Architecture:
- Each thread maintains a private min-heap of LOCAL_K best scores
- After the strided loop (no __syncthreads inside), threads write their
local top-k to shared memory
- Thread 0 builds the final top-k from all n_threads*LOCAL_K candidates
- For top_k=1024, n_threads=128, LOCAL_K=8: 1024 candidates = exact merge
- SMEM budget: w_h + merge heap + per-thread staging = ~30KB (well under 232KB)
Also updated the copy in dsv4/kernels/cuda/ (the one actually loaded
by the Python bridge).
Future optimization (separate from this fix):
- The dot products are scalar FP32 per thread. At 1M context this is slow.
Production path should use FP4 tcgen05 MMA (Stage F).
- The block-level merge is single-threaded. Could use warp-reduce or
bitonic sort for top_k > 256.
- Use half_step_to_e2m1 for E2M1 FP4 quantization (not LUT search)
- Use __nv_fp8_e4m3 + memcpy for block scale (not reinterpret_cast)
- Pack nibbles as (nibbles[2*i+1] << 4) | nibbles[2*i] (same as prod)
- Output uint8 buffers, then .view() to FP4/FP8 dtypes
- Handle near-zero block scale same as quantize_nvfp4.cu
Root cause: float row_max[n] is a VLA — not allowed in CUDA device code.
Fix: use shared memory with MHC_MAX_N=16 fixed-size slots.
Also: REMOVED the Python fallback in sinkhorn_knopp().
If the CUDA kernel fails, the pipeline DIES. No soft landing.
This is the correct behavior — silent fallback to broken precision
is worse than a loud crash.
The residual growth |X|→500-700 at L60 was likely caused by the Python
fallback running a DIFFERENT numerical path (BF16 accumulation in torch
ops vs FP32 in the CUDA kernel). With the fixed kernel, Sinkhorn should
produce properly doubly-stochastic B_l, bounding the residual.
block_reduce_sum/max write to smem[0..n_warps-1] but we passed &s_amax
(single float). For 128 threads / 4 warps, this wrote 4 floats starting
at &s_amax, corrupting adjacent shared variables (s_inv_rms, s_vals).
Fix: use s_scratch[8] array (4 for sum, 4 for max) with proper sizing.
CRITICAL: quantize must use the FP8-round-tripped block scale, not the raw
pre-FP8 value. The dequant reads the FP8 bytes back, so the quantize must
match exactly. Same pattern as quantize_nvfp4.cu. This was the root cause
of cos=0.925 (should be ~0.995).
Previous version used __shfl_down_sync for group-level amax reduction,
but shuffles operate at warp level and crossed group boundaries.
Fix: each thread independently quantizes its assigned 16-element blocks
from shared memory. Simpler and correct.
cute.arch.fmin/fmax take scalar Float32, not TensorSSA.
Replace with cute.where() and arithmetic for TensorSSA compatibility.
Also changed subtile loop to unroll=1 for cute.where() compatibility.
The __call__ method passes these 3 Optional params to self.kernel(),
but kernel() didn't accept them, causing TypeError: too many positional
arguments during cute.compile(). This was the CuTeDSL 'arg-binding bug'
blocking P0/P1.