# nvfp4-megamoe-kernel/dsv4/ops/rope.py
# (Python source; listed by the file browser as 140 lines, 4.9 KiB)
"""Inverse RoPE + NVFP4 wo_a grouped GEMM for DeepSeek V4 attention.
Replaces:
1. fused_inv_rope_fp8_quant (CUDA kernel) → inverse_rope_bf16 (Python)
2. deepseek_v4_fp8_einsum (DeepGEMM) → CuTeDSL NVFP4 grouped GEMM
The inverse RoPE is the conjugate rotation that undoes the RoPE applied
during attention. DeepSeek V4 uses GPT-J style (interleaved) RoPE.
For the RoPE portion of each head (last rope_dim=64 dims):
- Pair elements (x[2i], x[2i+1]) — interleaved (GPT-J style)
- Inverse (conjugate rotation):
x[2i] = x'[2i] * cos(θ_i) + x'[2i+1] * sin(θ_i)
x[2i+1] = -x'[2i] * sin(θ_i) + x'[2i+1] * cos(θ_i)
"""
import torch
def forward_rope_partial(
    x: torch.Tensor,
    positions: torch.Tensor,
    rope_dim: int = 64,
    head_dim: int = 512,
    cos_sin_cache: "torch.Tensor | None" = None,
    base: float = 10000.0,
) -> torch.Tensor:
    """Apply partial RoPE to the last rope_dim dimensions of each head.

    DSV4 uses GPT-J style (interleaved) RoPE on the last rope_dim dims of
    each head; the first nope_dim = head_dim - rope_dim dims are left
    unchanged (448 NoPE + 64 RoPE with the defaults).

    For the RoPE portion of each head:
      - Pair elements (x[2i], x[2i+1]) — interleaved
      - Forward rotation:
          x'[2i]   = x[2i] * cos(θ_i) - x[2i+1] * sin(θ_i)
          x'[2i+1] = x[2i] * sin(θ_i) + x[2i+1] * cos(θ_i)

    Args:
        x: (T, n_h * head_dim) tensor, flat across heads (BF16 in the model).
        positions: (T,) integer token positions.
        rope_dim: number of RoPE dims per head; must be even and <= head_dim.
        head_dim: total head dimension; must divide x.shape[1].
        cos_sin_cache: optional (max_pos, rope_dim) table laid out as
            cos || sin (first rope_dim // 2 columns cos, the rest sin), the
            same layout inverse_rope_bf16 consumes. When given it replaces
            the inline table, so the forward rotation uses exactly the same
            angles as the model's cache.
        base: frequency base of the inline table,
            θ_i = pos * base ** (-2i / rope_dim). Ignored when cos_sin_cache
            is given.

    Returns:
        (T, n_h * head_dim) tensor with the same dtype as x, with forward
        RoPE applied to the last rope_dim dims of every head.

    Raises:
        ValueError: if rope_dim is odd or outside [0, head_dim], head_dim
            does not divide x.shape[1], or cos_sin_cache's last dim is not
            rope_dim.
    """
    if rope_dim % 2 != 0 or not 0 <= rope_dim <= head_dim:
        raise ValueError(
            f"rope_dim must be even and in [0, head_dim={head_dim}], got {rope_dim}"
        )
    if x.shape[1] % head_dim != 0:
        raise ValueError(
            f"x.shape[1]={x.shape[1]} is not a multiple of head_dim={head_dim}"
        )
    T = x.shape[0]
    n_h = x.shape[1] // head_dim
    nope_dim = head_dim - rope_dim
    half_rope = rope_dim // 2

    # Per-token cos/sin of shape (T, half_rope), always in FP32: BF16 cos/sin
    # break cos^2 + sin^2 = 1, so the rotation would stop preserving norms.
    if cos_sin_cache is None:
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, rope_dim, 2, dtype=torch.float32, device=x.device)
                / rope_dim
            )
        )
        angles = torch.outer(positions.float(), inv_freq)
        cos_vals = torch.cos(angles)
        sin_vals = torch.sin(angles)
    else:
        if cos_sin_cache.shape[-1] != rope_dim:
            raise ValueError(
                f"cos_sin_cache last dim must be rope_dim={rope_dim}, "
                f"got {cos_sin_cache.shape[-1]}"
            )
        cos_vals = cos_sin_cache[positions, :half_rope].float()
        sin_vals = cos_sin_cache[positions, half_rope:].float()
    # (T, 1, half_rope) so one table row broadcasts over all heads of a token.
    cos_vals = cos_vals.unsqueeze(1)
    sin_vals = sin_vals.unsqueeze(1)

    x_heads = x.reshape(T, n_h, head_dim)
    # RoPE slice in FP32: (T, n_h, rope_dim), split into interleaved pairs.
    x_rope = x_heads[:, :, nope_dim:].float()
    x_even = x_rope[:, :, 0::2]
    x_odd = x_rope[:, :, 1::2]

    rot_even = x_even * cos_vals - x_odd * sin_vals
    rot_odd = x_even * sin_vals + x_odd * cos_vals
    # Re-interleave the pairs back into [e0, o0, e1, o1, ...].
    x_rot = torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)

    # NoPE dims are copied unchanged; the RoPE slice is cast back to x's own
    # dtype. A hard-coded BF16 cast would silently round FP32/FP16 inputs.
    result = x_heads.clone()
    result[:, :, nope_dim:] = x_rot.to(x.dtype)
    return result.reshape(T, n_h * head_dim)
def inverse_rope_bf16(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Apply inverse RoPE to the attention output.

    This is a pure-PyTorch replacement for vLLM's fused_inv_rope_fp8_quant
    CUDA kernel. It only does the inverse RoPE (no FP8 quantization) since we
    quantize to NVFP4 instead. The rotation is computed in FP32, and the
    result is returned in o's dtype (BF16 in the model, hence the name).

    For each head, dims [nope_dim, nope_dim + rope_dim) are treated as
    interleaved (GPT-J style) pairs and rotated by the conjugate angle:
        x[2i]   =  x'[2i] * cos(θ_i) + x'[2i+1] * sin(θ_i)
        x[2i+1] = -x'[2i] * sin(θ_i) + x'[2i+1] * cos(θ_i)

    Args:
        o: (num_tokens, n_local_heads, head_dim) attention output, where
            head_dim must equal nope_dim + rope_dim.
        positions: (num_tokens,) integer token positions used to index
            cos_sin_cache.
        cos_sin_cache: (max_pos, rope_dim) table, cos || sin concatenated
            (first rope_dim // 2 columns cos, the rest sin). Documented as
            float32; any dtype is upcast to FP32 before use.
        nope_dim: number of non-RoPE dims per head (448 by default).
        rope_dim: number of RoPE dims per head (64 by default); must be even.

    Returns:
        (num_tokens, n_local_heads, head_dim) tensor in o's dtype. The NoPE
        dims are copied unchanged and inverse RoPE is applied to the RoPE dims.

    Raises:
        ValueError: if o is not 3-D, rope_dim is odd, nope_dim + rope_dim
            != head_dim, or cos_sin_cache's last dim is not rope_dim.
    """
    if o.dim() != 3:
        raise ValueError(
            f"o must be (num_tokens, n_heads, head_dim), got shape {tuple(o.shape)}"
        )
    head_dim = o.shape[-1]
    if rope_dim % 2 != 0:
        raise ValueError(f"rope_dim must be even, got {rope_dim}")
    # Without this check a mismatched split either fails deep in broadcasting
    # or, for rope_dim == 2, silently rotates the wrong slice.
    if nope_dim < 0 or nope_dim + rope_dim != head_dim:
        raise ValueError(
            f"nope_dim ({nope_dim}) + rope_dim ({rope_dim}) must equal "
            f"head_dim ({head_dim})"
        )
    if cos_sin_cache.shape[-1] != rope_dim:
        raise ValueError(
            f"cos_sin_cache last dim must be rope_dim={rope_dim}, "
            f"got {cos_sin_cache.shape[-1]}"
        )
    half_rope = rope_dim // 2

    # Per-token cos/sin as (T, 1, half_rope), so they broadcast over heads.
    # CRITICAL: compute in FP32, not BF16! BF16 cos/sin destroys cos²+sin²=1.
    cos_all = cos_sin_cache[positions, :half_rope].float().unsqueeze(1)
    sin_all = cos_sin_cache[positions, half_rope:].float().unsqueeze(1)

    # RoPE portion (T, H, rope_dim) in FP32, split into interleaved pairs.
    o_rope = o[:, :, nope_dim:].float()
    o_even = o_rope[:, :, 0::2]
    o_odd = o_rope[:, :, 1::2]

    # Inverse (conjugate) rotation.
    inv_even = o_even * cos_all + o_odd * sin_all
    inv_odd = -o_even * sin_all + o_odd * cos_all

    # Interleave back into [e0, o0, e1, o1, ...].
    o_inv = torch.empty_like(o_rope)
    o_inv[:, :, 0::2] = inv_even
    o_inv[:, :, 1::2] = inv_odd

    # NoPE dims are copied unchanged. The RoPE slice is cast back to o's dtype
    # rather than a hard-coded BF16, which would silently round FP32/FP16 input.
    result = o.clone()
    result[:, :, nope_dim:] = o_inv.to(o.dtype)
    return result