revert: K=V with RoPE + inverse RoPE is the correct DSV4 approach
This commit is contained in:
@@ -353,16 +353,17 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
 q_heads = apply_rope_partial(q_heads, positions_dev, rope_cos, rope_sin, hd, rd)
 
 # -- Apply RoPE to KV (at current positions) BEFORE caching --
-# DSV4: RoPE applied to K only. V = raw KV output (no RoPE).
-# The cache stores K_rope and V_raw separately.
-kv_new_k = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd) # RoPE'd K
-kv_new_v = kv_new # V without RoPE
+# DSV4 convention: RoPE applied to KV before writing to cache.
+# K = V in DSV4 MQA (same projection, same RoPE'd tensor).
+kv_new = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd)
 
-# -- KV cache: append K (RoPE'd) and V (raw) --
-kv_cache.append(kv_new_k.permute(1, 0, 2), kv_new_v.permute(1, 0, 2)) # (1, T, hd)
+# -- KV cache: append RoPE'd KV (K=V) --
+k_new = kv_new # (T, 1, hd) — RoPE'd
+v_new = kv_new # K = V in DSV4 MQA
+kv_cache.append(k_new.permute(1, 0, 2), v_new.permute(1, 0, 2)) # (1, T, hd)
 
-# -- Get full KV from cache --
-k_full, v_full = kv_cache.get() # (1, seq_len, hd) — K is RoPE'd, V is raw
+# -- Get full KV from cache (already RoPE'd) --
+k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V
 seq_len = k_full.shape[1]
 
 # -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
@@ -403,8 +404,8 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
 attn_out = dsv4_attention(q_input, k_full, v_full)
 attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
 
-# -- No inverse RoPE needed (V is un-rotated, only K has RoPE) --
-# attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)
+# -- Inverse RoPE on attention output (paper §2.3.3) --
+attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)
 
 # -- Output projection: wo_a (grouped BMM) + wo_b (NVFP4) --
 # wo_a: grouped linear, (n_h, hd) → (n_groups, o_rank) via BMM
Reference in New Issue
Block a user