diff --git a/single_shot_inference.py b/single_shot_inference.py index 74fa9d91..aacc7d99 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -189,17 +189,47 @@ class mHCBlock: # RoPE — partial, GPT-J interleaved, last rope_dim dims # ===================================================================== -def build_rope_cache(max_pos, rope_dim, device, theta=10000.0): +def build_rope_cache(max_pos, rope_dim, device, theta=10000.0, + rope_type="default", rope_factor=1.0, + original_max_pos=4096, beta_fast=32, beta_slow=1): """Build cos/sin caches for partial RoPE. CRITICAL: FP32, not BF16! BF16 quantization destroys cos²+sin²=1 identity needed for inverse RoPE. BF16 cos²+sin² can be 0.996, causing ~3% round-trip error that accumulates across 61 layers. + Supports YaRN (Yet another RoPE extensioN) scaling for long context. + The DSV4 Pro model uses rope_type='yarn' with factor=16. + Returns: (cos_cache, sin_cache) each (max_pos, rope_dim//2) FP32 """ half = rope_dim // 2 + # Base frequencies: 1 / theta^(2i/d) freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim)) + + if rope_type == "yarn" and rope_factor > 1.0: + # YaRN frequency scaling + # Compute wavelength thresholds + low_freq_wavelen = original_max_pos / (beta_fast * 2.0) # High-freq cutoff + high_freq_wavelen = original_max_pos / (beta_slow * 2.0) # Low-freq cutoff + + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < low_freq_wavelen: + # High frequency: no scaling + new_freqs.append(freq) + elif wavelen > high_freq_wavelen: + # Low frequency: scale by 1/factor + new_freqs.append(freq / rope_factor) + else: + # Medium frequency: smooth interpolation + smooth = (original_max_pos / (wavelen * beta_slow) - rope_factor) / ( + rope_factor * (beta_fast / beta_slow - 1) + ) + new_freqs.append((1 - smooth) * freq / rope_factor + smooth * freq) + freqs = torch.tensor(new_freqs, dtype=torch.float32) + angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs) return torch.cos(angles).to(device), torch.sin(angles).to(device) @@ -759,7 +789,22 @@ def main(): final_norm_w = all_weights.get("model.norm.weight") if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0') - rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}") for g in range(NUM_GPUS)} + # Build RoPE caches with YaRN scaling from model config + rope_params = cfg.get("rope_parameters", {}) + rope_type = rope_params.get("rope_type", "default") + rope_factor = rope_params.get("factor", 1.0) + rope_theta = rope_params.get("rope_theta", cfg.get("rope_theta", 10000.0)) + original_max_pos = rope_params.get("original_max_position_embeddings", 4096) + beta_fast = rope_params.get("beta_fast", 32) + beta_slow = rope_params.get("beta_slow", 1) + print(f"RoPE: type={rope_type} factor={rope_factor} theta={rope_theta} " + f"orig_max_pos={original_max_pos} beta_fast={beta_fast} beta_slow={beta_slow}", flush=True) + rope_caches = {g: build_rope_cache( + 8192, rd, f"cuda:{g}", theta=rope_theta, + rope_type=rope_type, rope_factor=rope_factor, + original_max_pos=original_max_pos, + beta_fast=beta_fast, beta_slow=beta_slow + ) for g in range(NUM_GPUS)} # ==== KV caches (one per layer on its GPU) ==== kv_caches = {}