# nvfp4-megamoe-kernel/dsv4/ops/rope.py
"""Inverse RoPE + NVFP4 wo_a grouped GEMM for DeepSeek V4 attention.

Replaces:
  1. fused_inv_rope_fp8_quant (CUDA kernel) → inverse_rope_bf16 (Python)
  2. deepseek_v4_fp8_einsum (DeepGEMM)      → CuTeDSL NVFP4 grouped GEMM

The inverse RoPE is the conjugate rotation that undoes the RoPE applied
during attention. DeepSeek V4 uses GPT-J style (interleaved) RoPE.

For the RoPE portion of each head (the last rope_dim = 64 dims):
  - Elements are paired as (x[2i], x[2i+1]), interleaved (GPT-J style).
  - Given the rotated values x', the inverse (conjugate) rotation is
        x[2i]   =  x'[2i] * cos(θ_i) + x'[2i+1] * sin(θ_i)
        x[2i+1] = -x'[2i] * sin(θ_i) + x'[2i+1] * cos(θ_i)
    where θ_i is the rotation angle of pair i at the token's position.
"""
import torch


def inverse_rope_bf16(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Apply inverse RoPE to a BF16 attention output.

    This is a pure-Python replacement for vLLM's fused_inv_rope_fp8_quant
    CUDA kernel. It performs only the inverse RoPE (no FP8 quantization),
    because the output is quantized to NVFP4 afterwards instead.

    Args:
        o: (num_tokens, n_local_heads, head_dim) BF16 attention output,
            with head_dim == nope_dim + rope_dim.
        positions: (num_tokens,) int64 token positions.
        cos_sin_cache: (max_pos, rope_dim) float32 table; each row holds
            rope_dim // 2 cosines followed by rope_dim // 2 sines (cos || sin).
        nope_dim: number of non-RoPE dims per head (default 448).
        rope_dim: number of RoPE dims per head (default 64).

    Returns:
        (num_tokens, n_local_heads, head_dim) tensor in o's dtype, with the
        NoPE dims unchanged and the inverse rotation applied to the RoPE dims.
        The rotation itself is computed in float32 to avoid rounding cos/sin
        to BF16.
    """
    num_tokens, num_heads, head_dim = o.shape
    assert head_dim == nope_dim + rope_dim, (
        f"head_dim={head_dim} must equal nope_dim + rope_dim = {nope_dim + rope_dim}"
    )
    half_rope = rope_dim // 2

    # cos/sin for each token's position: (T, half_rope), kept in float32.
    cos_all = cos_sin_cache[positions, :half_rope]
    sin_all = cos_sin_cache[positions, half_rope:]
    # Add a head axis so they broadcast over heads: (T, 1, half_rope).
    cos_all = cos_all.unsqueeze(1).float()
    sin_all = sin_all.unsqueeze(1).float()

    # RoPE portion of each head, upcast to float32: (T, H, rope_dim).
    o_rope = o[:, :, nope_dim:].float()
    # Split into the interleaved (GPT-J style) even/odd pairs: (T, H, half_rope).
    o_even = o_rope[:, :, 0::2]
    o_odd = o_rope[:, :, 1::2]

    # Inverse (conjugate) rotation:
    #   inv[2i]   =  x[2i] * cos + x[2i+1] * sin
    #   inv[2i+1] = -x[2i] * sin + x[2i+1] * cos
    inv_even = o_even * cos_all + o_odd * sin_all
    inv_odd = -o_even * sin_all + o_odd * cos_all

    # Interleave the pairs back into (T, H, rope_dim).
    o_inv = torch.empty_like(o_rope)
    o_inv[:, :, 0::2] = inv_even
    o_inv[:, :, 1::2] = inv_odd

    # Copy the NoPE portion unchanged and replace the RoPE portion
    # (the assignment casts float32 back to o.dtype).
    result = o.clone()
    result[:, :, nope_dim:] = o_inv
    return result
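

# ---------------------------------------------------------------------------
# Reference check (illustrative sketch, not part of the original module).
# One hedged way to sanity-check inverse_rope_bf16 is to compare it against an
# independent formulation. Each interleaved pair (x[2i], x[2i+1]) is treated as
# the complex number x[2i] + j*x[2i+1]. Forward GPT-J RoPE multiplies it by
# exp(+j*θ_i), and the inverse multiplies it by the conjugate exp(-j*θ_i):
#   x'[2i]   = x[2i] * cos(θ_i) - x[2i+1] * sin(θ_i)
#   x'[2i+1] = x[2i] * sin(θ_i) + x[2i+1] * cos(θ_i)
# Assumptions that are NOT taken from the original source:
#   - the cache uses base 10000 with θ_i = pos * base**(-2i / rope_dim);
#   - all helper names below are hypothetical.
# `import torch` is repeated so this block also runs on its own.
# ---------------------------------------------------------------------------
import torch


def build_cos_sin_cache_ref(
    max_pos: int, rope_dim: int = 64, base: float = 10000.0
) -> torch.Tensor:
    """Hypothetical (max_pos, rope_dim) float32 cache laid out as cos || sin."""
    inv_freq = base ** (-torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim)
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
    return torch.cat([angles.cos(), angles.sin()], dim=-1)


def _rotate_rope_part_ref(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int,
    rope_dim: int,
    inverse: bool,
) -> torch.Tensor:
    assert o.shape[-1] == nope_dim + rope_dim
    half = rope_dim // 2
    cos = cos_sin_cache[positions, :half].float()
    sin = cos_sin_cache[positions, half:].float()
    # exp(+j*θ) for the forward rotation, its conjugate exp(-j*θ) for the inverse.
    rot = torch.complex(cos, -sin if inverse else sin).unsqueeze(1)  # (T, 1, half)
    # Upcast and make contiguous so the trailing (half, 2) view suits view_as_complex.
    rope = o[..., nope_dim:].float().contiguous()
    z = torch.view_as_complex(rope.reshape(*rope.shape[:-1], half, 2))
    # view_as_real yields (real, imag) = (even, odd); flatten re-interleaves them.
    rotated = torch.view_as_real(z * rot).flatten(-2).to(o.dtype)
    return torch.cat([o[..., :nope_dim], rotated], dim=-1)


def apply_rope_gptj_ref(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Hypothetical forward GPT-J (interleaved) RoPE on the last rope_dim dims."""
    return _rotate_rope_part_ref(o, positions, cos_sin_cache, nope_dim, rope_dim, inverse=False)


def inverse_rope_complex_ref(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Hypothetical complex-number reference for the inverse (conjugate) rotation."""
    return _rotate_rope_part_ref(o, positions, cos_sin_cache, nope_dim, rope_dim, inverse=True)


# Example cross-check (BF16 rounding means an exact match is not expected):
#   cache = build_cos_sin_cache_ref(max_pos=4096)
#   pos = torch.randint(0, 4096, (8,))
#   x = torch.randn(8, 4, 512).to(torch.bfloat16)
#   y = apply_rope_gptj_ref(x, pos, cache)
#   print((inverse_rope_bf16(y, pos, cache).float() - x.float()).abs().max())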