diff --git a/single_shot_inference.py b/single_shot_inference.py index 8a1c931e..6f244374 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -353,16 +353,17 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, q_heads = apply_rope_partial(q_heads, positions_dev, rope_cos, rope_sin, hd, rd) # -- Apply RoPE to KV (at current positions) BEFORE caching -- - # DSV4: RoPE applied to K only. V = raw KV output (no RoPE). - # The cache stores K_rope and V_raw separately. - kv_new_k = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd) # RoPE'd K - kv_new_v = kv_new # V without RoPE + # DSV4 convention: RoPE applied to KV before writing to cache. + # K = V in DSV4 MQA (same projection, same RoPE'd tensor). + kv_new = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd) - # -- KV cache: append K (RoPE'd) and V (raw) -- - kv_cache.append(kv_new_k.permute(1, 0, 2), kv_new_v.permute(1, 0, 2)) # (1, T, hd) + # -- KV cache: append RoPE'd KV (K=V) -- + k_new = kv_new # (T, 1, hd) — RoPE'd + v_new = kv_new # K = V in DSV4 MQA + kv_cache.append(k_new.permute(1, 0, 2), v_new.permute(1, 0, 2)) # (1, T, hd) - # -- Get full KV from cache -- - k_full, v_full = kv_cache.get() # (1, seq_len, hd) — K is RoPE'd, V is raw + # -- Get full KV from cache (already RoPE'd) -- + k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V seq_len = k_full.shape[1] # -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) -- @@ -403,8 +404,8 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_out = dsv4_attention(q_input, k_full, v_full) attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd) - # -- No inverse RoPE needed (V is un-rotated, only K has RoPE) -- - # attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd) + # -- Inverse RoPE on attention output (paper §2.3.3) -- + attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd) # -- Output projection: wo_a (grouped BMM) + wo_b (NVFP4) -- # wo_a: grouped linear, (n_h, hd) → (n_groups, o_rank) via BMM