CRITICAL FIX: Add YaRN RoPE scaling (factor=16)
The DSV4 Pro model uses rope_type='yarn' with factor=16. Our build_rope_cache was using standard RoPE with theta=10000, completely ignoring YaRN scaling. This produced wrong cos/sin values for all positions, causing incorrect attention scores and garbage output. YaRN modifies the RoPE frequencies: - High-frequency components: unchanged - Low-frequency components: scaled by 1/factor - Medium: smooth interpolation Config: factor=16, beta_fast=32, beta_slow=1, orig_max_pos=65536
This commit is contained in:
@@ -189,17 +189,47 @@ class mHCBlock:
|
||||
# RoPE — partial, GPT-J interleaved, last rope_dim dims
|
||||
# =====================================================================
|
||||
|
||||
def build_rope_cache(max_pos, rope_dim, device, theta=10000.0):
|
||||
def build_rope_cache(max_pos, rope_dim, device, theta=10000.0,
|
||||
rope_type="default", rope_factor=1.0,
|
||||
original_max_pos=4096, beta_fast=32, beta_slow=1):
|
||||
"""Build cos/sin caches for partial RoPE.
|
||||
|
||||
CRITICAL: FP32, not BF16! BF16 quantization destroys cos²+sin²=1
|
||||
identity needed for inverse RoPE. BF16 cos²+sin² can be 0.996,
|
||||
causing ~3% round-trip error that accumulates across 61 layers.
|
||||
|
||||
Supports YaRN (Yet another RoPE extensioN) scaling for long context.
|
||||
The DSV4 Pro model uses rope_type='yarn' with factor=16.
|
||||
|
||||
Returns: (cos_cache, sin_cache) each (max_pos, rope_dim//2) FP32
|
||||
"""
|
||||
half = rope_dim // 2
|
||||
# Base frequencies: 1 / theta^(2i/d)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
|
||||
|
||||
if rope_type == "yarn" and rope_factor > 1.0:
|
||||
# YaRN frequency scaling
|
||||
# Compute wavelength thresholds
|
||||
low_freq_wavelen = original_max_pos / (beta_fast * 2.0) # High-freq cutoff
|
||||
high_freq_wavelen = original_max_pos / (beta_slow * 2.0) # Low-freq cutoff
|
||||
|
||||
new_freqs = []
|
||||
for freq in freqs:
|
||||
wavelen = 2 * math.pi / freq
|
||||
if wavelen < low_freq_wavelen:
|
||||
# High frequency: no scaling
|
||||
new_freqs.append(freq)
|
||||
elif wavelen > high_freq_wavelen:
|
||||
# Low frequency: scale by 1/factor
|
||||
new_freqs.append(freq / rope_factor)
|
||||
else:
|
||||
# Medium frequency: smooth interpolation
|
||||
smooth = (original_max_pos / (wavelen * beta_slow) - rope_factor) / (
|
||||
rope_factor * (beta_fast / beta_slow - 1)
|
||||
)
|
||||
new_freqs.append((1 - smooth) * freq / rope_factor + smooth * freq)
|
||||
freqs = torch.tensor(new_freqs, dtype=torch.float32)
|
||||
|
||||
angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
|
||||
return torch.cos(angles).to(device), torch.sin(angles).to(device)
|
||||
|
||||
@@ -759,7 +789,22 @@ def main():
|
||||
final_norm_w = all_weights.get("model.norm.weight")
|
||||
if final_norm_w is not None:
|
||||
final_norm_w = final_norm_w.to('cuda:0')
|
||||
rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}") for g in range(NUM_GPUS)}
|
||||
# Build RoPE caches with YaRN scaling from model config
|
||||
rope_params = cfg.get("rope_parameters", {})
|
||||
rope_type = rope_params.get("rope_type", "default")
|
||||
rope_factor = rope_params.get("factor", 1.0)
|
||||
rope_theta = rope_params.get("rope_theta", cfg.get("rope_theta", 10000.0))
|
||||
original_max_pos = rope_params.get("original_max_position_embeddings", 4096)
|
||||
beta_fast = rope_params.get("beta_fast", 32)
|
||||
beta_slow = rope_params.get("beta_slow", 1)
|
||||
print(f"RoPE: type={rope_type} factor={rope_factor} theta={rope_theta} "
|
||||
f"orig_max_pos={original_max_pos} beta_fast={beta_fast} beta_slow={beta_slow}", flush=True)
|
||||
rope_caches = {g: build_rope_cache(
|
||||
8192, rd, f"cuda:{g}", theta=rope_theta,
|
||||
rope_type=rope_type, rope_factor=rope_factor,
|
||||
original_max_pos=original_max_pos,
|
||||
beta_fast=beta_fast, beta_slow=beta_slow
|
||||
) for g in range(NUM_GPUS)}
|
||||
|
||||
# ==== KV caches (one per layer on its GPU) ====
|
||||
kv_caches = {}
|
||||
|
||||
Reference in New Issue
Block a user