diff --git a/single_shot_inference.py b/single_shot_inference.py index 6f244374..8a1c931e 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -353,17 +353,16 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, q_heads = apply_rope_partial(q_heads, positions_dev, rope_cos, rope_sin, hd, rd) # -- Apply RoPE to KV (at current positions) BEFORE caching -- - # DSV4 convention: RoPE applied to KV before writing to cache. - # K = V in DSV4 MQA (same projection, same RoPE'd tensor). - kv_new = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd) + # DSV4: RoPE applied to K only. V = raw KV output (no RoPE). + # The cache stores K_rope and V_raw separately. + kv_new_k = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd) # RoPE'd K + kv_new_v = kv_new # V without RoPE - # -- KV cache: append RoPE'd KV (K=V) -- - k_new = kv_new # (T, 1, hd) — RoPE'd - v_new = kv_new # K = V in DSV4 MQA - kv_cache.append(k_new.permute(1, 0, 2), v_new.permute(1, 0, 2)) # (1, T, hd) + # -- KV cache: append K (RoPE'd) and V (raw) -- + kv_cache.append(kv_new_k.permute(1, 0, 2), kv_new_v.permute(1, 0, 2)) # (1, T, hd) - # -- Get full KV from cache (already RoPE'd) -- - k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V + # -- Get full KV from cache -- + k_full, v_full = kv_cache.get() # (1, seq_len, hd) — K is RoPE'd, V is raw seq_len = k_full.shape[1] # -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) -- @@ -404,8 +403,8 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_out = dsv4_attention(q_input, k_full, v_full) attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd) - # -- Inverse RoPE on attention output (paper §2.3.3) -- - attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd) + # -- No inverse RoPE needed (V is un-rotated, only K has RoPE) -- + # attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd) # -- Output projection: wo_a (grouped BMM) + wo_b (NVFP4) -- # wo_a: grouped linear, (n_h, hd) → (n_groups, o_rank) via BMM