nvfp4-megamoe-kernel/dsv4/ops/rope.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

"""Inverse RoPE + NVFP4 wo_a grouped GEMM for DeepSeek V4 attention.

Replaces:
1. fused_inv_rope_fp8_quant (CUDA kernel) → inverse_rope_bf16 (PyTorch)
2. deepseek_v4_fp8_einsum (DeepGEMM) → CuTeDSL NVFP4 grouped GEMM

The inverse RoPE is the conjugate rotation that undoes the RoPE applied
during attention. DeepSeek V4 uses GPT-J style (interleaved) RoPE.

For the RoPE portion of each head (the last rope_dim=64 dims):
- Pair elements (x[2i], x[2i+1]), interleaved (GPT-J style)
- Inverse (conjugate rotation), where x' is the rotated input:
    x[2i]   =  x'[2i] * cos(θ_i) + x'[2i+1] * sin(θ_i)
    x[2i+1] = -x'[2i] * sin(θ_i) + x'[2i+1] * cos(θ_i)
"""
import torch


def inverse_rope_bf16(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Apply inverse RoPE to attention output in BF16.

    This is a pure-PyTorch replacement for vLLM's fused_inv_rope_fp8_quant
    CUDA kernel. It only performs the inverse RoPE (no FP8 quantization),
    since we quantize to NVFP4 instead.

    Args:
        o: (num_tokens, n_local_heads, head_dim) BF16 attention output,
            with head_dim == nope_dim + rope_dim
        positions: (num_tokens,) int64 token positions
        cos_sin_cache: (max_pos, rope_dim) float32; cos and sin concatenated
            along the last dim (cos in the first rope_dim // 2 columns)
        nope_dim: number of non-RoPE dims per head (448)
        rope_dim: number of RoPE dims per head (64)

    Returns:
        (num_tokens, n_local_heads, head_dim) BF16 tensor with inverse RoPE
        applied to the last rope_dim dims; `o` itself is not modified.
    """
    num_tokens, num_heads, head_dim = o.shape
    assert head_dim == nope_dim + rope_dim, (
        f"head_dim={head_dim} != nope_dim + rope_dim = {nope_dim + rope_dim}"
    )
    half_rope = rope_dim // 2

    # Get cos/sin for each position: (T, half_rope)
    cos_all = cos_sin_cache[positions, :half_rope]
    sin_all = cos_sin_cache[positions, half_rope:]

    # Expand for broadcasting: (T, 1, half_rope) → broadcasts over heads
    cos_all = cos_all.unsqueeze(1).to(o.dtype)
    sin_all = sin_all.unsqueeze(1).to(o.dtype)

    # Extract RoPE portion: (T, H, rope_dim)
    o_rope = o[:, :, nope_dim:]

    # Split into even/odd pairs (interleaved GPT-J style): (T, H, half_rope)
    o_even = o_rope[:, :, 0::2]
    o_odd = o_rope[:, :, 1::2]

    # Inverse rotation (conjugate):
    #   inv[2i]   =  x[2i] * cos + x[2i+1] * sin
    #   inv[2i+1] = -x[2i] * sin + x[2i+1] * cos
    inv_even = o_even * cos_all + o_odd * sin_all
    inv_odd = -o_even * sin_all + o_odd * cos_all

    # Interleave back
    o_inv = torch.empty_like(o_rope)
    o_inv[:, :, 0::2] = inv_even
    o_inv[:, :, 1::2] = inv_odd

    # Copy NoPE portion unchanged, replace RoPE portion
    result = o.clone()
    result[:, :, nope_dim:] = o_inv
    return result
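
The inverse formulas in the module docstring can be sanity-checked on a single interleaved pair. This sketch assumes the standard GPT-J forward rotation (x'[2i] = x[2i]·cos θ - x[2i+1]·sin θ, x'[2i+1] = x[2i]·sin θ + x[2i+1]·cos θ), which this file does not itself define; `rope_pair` and `inverse_rope_pair` are hypothetical helpers written only for illustration.

```python
import math


def rope_pair(x_even: float, x_odd: float, theta: float) -> tuple[float, float]:
    """Forward GPT-J RoPE on one (even, odd) pair (assumed standard convention)."""
    c, s = math.cos(theta), math.sin(theta)
    return x_even * c - x_odd * s, x_even * s + x_odd * c


def inverse_rope_pair(y_even: float, y_odd: float, theta: float) -> tuple[float, float]:
    """Conjugate rotation, using the same formulas as the module docstring."""
    c, s = math.cos(theta), math.sin(theta)
    return y_even * c + y_odd * s, -y_even * s + y_odd * c


# Round trip: rotate by theta, then apply the conjugate rotation.
x = (0.75, -1.25)
theta = 0.3
y = rope_pair(*x, theta)
x_back = inverse_rope_pair(*y, theta)
print(x_back)  # equal to x up to floating-point rounding
```

Because the inverse is the transpose of a 2×2 rotation matrix, the round trip is exact in real arithmetic: the cross terms cancel and only cos²θ + sin²θ = 1 remains on each component.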
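
`inverse_rope_bf16` expects `cos_sin_cache` laid out as cos||sin per position. The sketch below builds a cache in that layout and round-trips a tensor through a forward GPT-J rotation and back. It assumes plain RoPE frequencies with base 10000 and no context-length scaling; the actual DeepSeek V4 cache may be built differently, and `build_cos_sin_cache`, `apply_rope_gptj`, and `inverse_rope_complex` are hypothetical names. The inverse here uses complex arithmetic (each (even, odd) pair viewed as even + i·odd and multiplied by e^{-iθ}), a different formulation from the real-valued function above that makes the "conjugate rotation" wording literal.

```python
import torch


def build_cos_sin_cache(max_pos: int, rope_dim: int = 64, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical cos||sin cache: plain RoPE frequencies, no scaling (assumption)."""
    half = rope_dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) * 2 / rope_dim))
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)  # (max_pos, half)
    return torch.cat([angles.cos(), angles.sin()], dim=-1)  # (max_pos, rope_dim)


def apply_rope_gptj(x_rope: torch.Tensor, positions: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    """Forward interleaved (GPT-J) RoPE on x_rope of shape (T, H, rope_dim)."""
    half = x_rope.shape[-1] // 2
    cos = cache[positions, :half].unsqueeze(1)  # (T, 1, half)
    sin = cache[positions, half:].unsqueeze(1)
    even, odd = x_rope[..., 0::2], x_rope[..., 1::2]
    out = torch.empty_like(x_rope)
    out[..., 0::2] = even * cos - odd * sin
    out[..., 1::2] = even * sin + odd * cos
    return out


def inverse_rope_complex(y_rope: torch.Tensor, positions: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    """Inverse RoPE as multiplication of each pair by conj(e^{i*theta}) = cos - i*sin."""
    half = y_rope.shape[-1] // 2
    rot_conj = torch.complex(cache[positions, :half], -cache[positions, half:]).unsqueeze(1)
    # (T, H, rope_dim) -> (T, H, half, 2) -> complex (T, H, half): real=even, imag=odd
    pairs = torch.view_as_complex(y_rope.float().reshape(*y_rope.shape[:-1], half, 2).contiguous())
    # Back to interleaved real layout (T, H, rope_dim) in the input dtype
    return torch.view_as_real(pairs * rot_conj).flatten(-2).to(y_rope.dtype)


# Usage: rotate only the RoPE slice, then undo it.
torch.manual_seed(0)
T, H, nope_dim, rope_dim = 4, 2, 8, 64
cache = build_cos_sin_cache(max_pos=128, rope_dim=rope_dim)
positions = torch.tensor([0, 5, 17, 127])
x = torch.randn(T, H, nope_dim + rope_dim)
y = x.clone()
y[..., nope_dim:] = apply_rope_gptj(x[..., nope_dim:], positions, cache)
x_back = y.clone()
x_back[..., nope_dim:] = inverse_rope_complex(y[..., nope_dim:], positions, cache)
print(torch.allclose(x_back, x, atol=1e-5))
```

This sketch computes the rotation in float32 and casts back to the input dtype only at the end, whereas `inverse_rope_bf16` casts cos/sin to the attention output's dtype (BF16) before multiplying.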