revert: K=V with RoPE + inverse RoPE is the correct DSV4 approach

This commit is contained in:
2026-05-31 04:51:10 +00:00
parent 781ee43521
commit 738088cf49

View File

@@ -353,16 +353,17 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
 q_heads = apply_rope_partial(q_heads, positions_dev, rope_cos, rope_sin, hd, rd)
 # -- Apply RoPE to KV (at current positions) BEFORE caching --
-# DSV4: RoPE applied to K only. V = raw KV output (no RoPE).
-# The cache stores K_rope and V_raw separately.
-kv_new_k = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd) # RoPE'd K
-kv_new_v = kv_new # V without RoPE
+# DSV4 convention: RoPE applied to KV before writing to cache.
+# K = V in DSV4 MQA (same projection, same RoPE'd tensor).
+kv_new = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd)
-# -- KV cache: append K (RoPE'd) and V (raw) --
-kv_cache.append(kv_new_k.permute(1, 0, 2), kv_new_v.permute(1, 0, 2)) # (1, T, hd)
+# -- KV cache: append RoPE'd KV (K=V) --
+k_new = kv_new # (T, 1, hd) — RoPE'd
+v_new = kv_new # K = V in DSV4 MQA
+kv_cache.append(k_new.permute(1, 0, 2), v_new.permute(1, 0, 2)) # (1, T, hd)
-# -- Get full KV from cache --
-k_full, v_full = kv_cache.get() # (1, seq_len, hd) — K is RoPE'd, V is raw
+# -- Get full KV from cache (already RoPE'd) --
+k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V
 seq_len = k_full.shape[1]
 # -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
@@ -403,8 +404,8 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
 attn_out = dsv4_attention(q_input, k_full, v_full)
 attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
-# -- No inverse RoPE needed (V is un-rotated, only K has RoPE) --
-# attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)
+# -- Inverse RoPE on attention output (paper §2.3.3) --
+attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)
 # -- Output projection: wo_a (grouped BMM) + wo_b (NVFP4) --
 # wo_a: grouped linear, (n_h, hd) → (n_groups, o_rank) via BMM