"""Inverse RoPE + NVFP4 wo_a grouped GEMM for DeepSeek V4 attention.

Replaces:
  1. fused_inv_rope_fp8_quant (CUDA kernel) → inverse_rope_bf16 (Python)
  2. deepseek_v4_fp8_einsum (DeepGEMM) → CuTeDSL NVFP4 grouped GEMM

The inverse RoPE is the conjugate rotation that undoes the RoPE applied
during attention. DeepSeek V4 uses GPT-J style (interleaved) RoPE. For the
RoPE portion of each head (the last rope_dim=64 dims):

  - Elements are paired as (x[2i], x[2i+1]) — interleaved, GPT-J style.
  - Inverse (conjugate) rotation:
        x[2i]   =  x'[2i] * cos(θ_i) + x'[2i+1] * sin(θ_i)
        x[2i+1] = -x'[2i] * sin(θ_i) + x'[2i+1] * cos(θ_i)
"""

import torch


def inverse_rope_bf16(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Apply inverse RoPE to the attention output, returning it in o's dtype.

    This is a pure-Python replacement for vLLM's fused_inv_rope_fp8_quant
    CUDA kernel. It only does the inverse RoPE (no FP8 quantization), since
    the result is quantized to NVFP4 afterwards.

    The rotation is computed in at least float32 and rounded to ``o.dtype``
    once at the end. Doing the multiply-adds directly in BF16 would round
    every intermediate product and sum to an 8-bit mantissa.

    Args:
        o: (num_tokens, n_local_heads, head_dim) attention output, typically
            BF16. head_dim must equal nope_dim + rope_dim.
        positions: (num_tokens,) integer token positions used to index
            cos_sin_cache.
        cos_sin_cache: (max_pos, rope_dim) table. Columns [0, rope_dim/2)
            hold cos(θ_i) and columns [rope_dim/2, rope_dim) hold sin(θ_i).
        nope_dim: number of leading non-RoPE dims per head (default 448).
        rope_dim: number of trailing RoPE dims per head (default 64). Must be
            even.

    Returns:
        A new tensor with the same shape and dtype as ``o``. The NoPE dims
        are copied unchanged and the RoPE dims are inverse-rotated.

    Raises:
        ValueError: if ``o`` is not 3-D, if ``rope_dim`` is odd, or if
            head_dim != nope_dim + rope_dim.
    """
    # Unpacking enforces a 3-D input (raises ValueError otherwise).
    _, _, head_dim = o.shape
    if rope_dim % 2 != 0:
        raise ValueError(f"rope_dim must be even, got {rope_dim}")
    if head_dim != nope_dim + rope_dim:
        raise ValueError(
            f"head_dim ({head_dim}) must equal nope_dim + rope_dim "
            f"({nope_dim} + {rope_dim})"
        )
    half_rope = rope_dim // 2

    # Never compute the rotation below float32 precision. Inputs that are
    # already float32/float64 keep their own precision.
    compute_dtype = torch.promote_types(o.dtype, torch.float32)

    # Per-token cos/sin rows, shaped (T, 1, half_rope) to broadcast over heads.
    # Bound the sin slice to rope_dim so a wider cache cannot leak extra columns.
    cos_sin = cos_sin_cache[positions].to(compute_dtype)
    cos = cos_sin[:, :half_rope].unsqueeze(1)
    sin = cos_sin[:, half_rope:rope_dim].unsqueeze(1)

    # RoPE portion of every head, split into interleaved (GPT-J) pairs.
    x_rope = o[:, :, nope_dim:].to(compute_dtype)
    x_even = x_rope[:, :, 0::2]
    x_odd = x_rope[:, :, 1::2]

    # Conjugate rotation (rotate by -θ):
    #   inv[2i]   =  x[2i] * cos + x[2i+1] * sin
    #   inv[2i+1] = -x[2i] * sin + x[2i+1] * cos
    inv_even = x_even * cos + x_odd * sin
    inv_odd = -x_even * sin + x_odd * cos

    # Re-interleave: stacking on a new last axis and flattening it yields
    # [e0, o0, e1, o1, ...]. Round back to the caller's dtype only once.
    inv_rope = torch.stack((inv_even, inv_odd), dim=-1).flatten(-2).to(o.dtype)

    # NoPE dims pass through untouched; RoPE dims are replaced.
    return torch.cat((o[:, :, :nope_dim], inv_rope), dim=-1)