- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py - Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc. - Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda) - Moved PyTorch bridges to dsv4/ops/ - Moved nn.Module layers to dsv4layers/ - Moved reference implementations to dsv4/reference/ - Moved vendored CUTLASS code to vendored/ - Archived ~190 debug tests to tests/archive/ - Kept ~15 canonical tests in tests/unit/ - Updated all import paths - Added stubs for future components (model/, cache/, loader/) - Updated pyproject.toml: dsv4-inference package name
77 lines
2.6 KiB
Python
77 lines
2.6 KiB
Python
"""Inverse RoPE + NVFP4 wo_a grouped GEMM for DeepSeek V4 attention.

Replaces:
  1. fused_inv_rope_fp8_quant (CUDA kernel) → inverse_rope_bf16 (Python)
  2. deepseek_v4_fp8_einsum (DeepGEMM) → CuTeDSL NVFP4 grouped GEMM

The inverse RoPE is the conjugate rotation that undoes the RoPE applied
during attention. DeepSeek V4 uses GPT-J style (interleaved) RoPE.

For the RoPE portion of each head (the last rope_dim=64 dims):
  - Pair elements (x[2i], x[2i+1]), interleaved (GPT-J style).
  - Apply the inverse (conjugate) rotation, where x' is the rotated input
    and x is the recovered value:
        x[2i]   =  x'[2i] * cos(θ_i) + x'[2i+1] * sin(θ_i)
        x[2i+1] = -x'[2i] * sin(θ_i) + x'[2i+1] * cos(θ_i)
"""
|
|
|
|
import torch
|
|
|
|
|
|
def inverse_rope_bf16(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Apply inverse (conjugate) GPT-J-style RoPE to the RoPE slice of ``o``.

    This is a pure-PyTorch replacement for vLLM's
    ``fused_inv_rope_fp8_quant`` CUDA kernel. It only performs the inverse
    RoPE (no FP8 quantization) since the result is quantized to NVFP4
    instead.

    The rotation is evaluated in at least float32 and rounded back to
    ``o.dtype`` once at the end. Casting cos/sin to BF16 and doing the
    multiply-adds in BF16 would instead round the table entries, every
    product and every sum separately. That hurts accuracy most when the
    two terms of a sum nearly cancel.

    Args:
        o: (num_tokens, n_local_heads, head_dim) attention output, normally
            BF16. ``head_dim`` must equal ``nope_dim + rope_dim``.
        positions: (num_tokens,) integer token positions, used as row
            indices into ``cos_sin_cache``.
        cos_sin_cache: (max_pos, rope_dim) table with cos and sin
            concatenated. Columns ``[:rope_dim // 2]`` hold cos(θ_i) and
            columns ``[rope_dim // 2:]`` hold sin(θ_i).
        nope_dim: number of leading non-RoPE dims per head (default 448).
        rope_dim: number of trailing RoPE dims per head (default 64); must
            be even.

    Returns:
        A new tensor with the same shape and dtype as ``o``. The first
        ``nope_dim`` dims of every head are copied unchanged, and the last
        ``rope_dim`` dims have the inverse rotation applied. ``o`` itself
        is not modified.

    Raises:
        ValueError: if ``o`` is not 3-D, ``rope_dim`` is odd,
            ``head_dim != nope_dim + rope_dim``, or the last dim of
            ``cos_sin_cache`` is not ``rope_dim``.
    """
    if o.dim() != 3:
        raise ValueError(
            "expected o with shape (num_tokens, n_local_heads, head_dim), "
            f"got {tuple(o.shape)}"
        )
    head_dim = o.shape[-1]
    if rope_dim % 2 != 0:
        raise ValueError(f"rope_dim must be even, got {rope_dim}")
    if head_dim != nope_dim + rope_dim:
        # Without this check a mismatch only surfaces later as an opaque
        # broadcasting error deep inside the rotation.
        raise ValueError(
            f"head_dim ({head_dim}) must equal nope_dim + rope_dim "
            f"({nope_dim} + {rope_dim})"
        )
    if cos_sin_cache.shape[-1] != rope_dim:
        raise ValueError(
            f"cos_sin_cache last dim ({cos_sin_cache.shape[-1]}) must equal "
            f"rope_dim ({rope_dim})"
        )
    half_rope = rope_dim // 2

    # float32 for BF16/FP16 inputs; float32/float64 inputs keep their dtype.
    compute_dtype = torch.promote_types(o.dtype, torch.float32)

    # Gather per-token cos/sin rows: (T, rope_dim) -> (T, 1, rope_dim), so
    # they broadcast over the head dimension.
    cos_sin = cos_sin_cache[positions].to(compute_dtype).unsqueeze(1)
    cos = cos_sin[..., :half_rope]
    sin = cos_sin[..., half_rope:]

    # RoPE slice of each head, split into interleaved (GPT-J style) pairs
    # (x[2i], x[2i+1]). Each half has shape (T, H, half_rope).
    o_rope = o[:, :, nope_dim:].to(compute_dtype)
    x_even = o_rope[..., 0::2]
    x_odd = o_rope[..., 1::2]

    # Conjugate rotation:
    #   inv[2i]   =  x[2i] * cos + x[2i+1] * sin
    #   inv[2i+1] = -x[2i] * sin + x[2i+1] * cos
    inv_even = x_even * cos + x_odd * sin
    inv_odd = -x_even * sin + x_odd * cos

    # Re-interleave. Stacking on a new last axis and flattening it places
    # inv_even at even offsets and inv_odd at odd offsets. The single
    # rounding to the output dtype happens here.
    o_inv = torch.stack((inv_even, inv_odd), dim=-1).flatten(-2).to(o.dtype)

    # Copy the NoPE portion unchanged and overwrite the RoPE portion.
    result = o.clone()
    result[:, :, nope_dim:] = o_inv
    return result
|