fucken aye
This commit is contained in:
@@ -63,26 +63,64 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
scale_exp = (scale_bits >> 23) & 0xFF
|
||||
scale_mant = scale_bits & 0x7FFFFF
|
||||
|
||||
# Convert FP32 → E4M3 manually
|
||||
e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7
|
||||
e4m3_exp = tl.maximum(e4m3_exp, 0)
|
||||
e4m3_exp = tl.minimum(e4m3_exp, 15)
|
||||
e4m3_mant = scale_mant >> 20
|
||||
round_bit = (scale_mant >> 19) & 1
|
||||
e4m3_mant = e4m3_mant + round_bit
|
||||
# Convert FP32 → E4M3 manually (with subnormal support)
|
||||
# FP32 bias=127, E4M3 bias=7 → raw exp = scale_exp - 120
|
||||
e4m3_exp_raw = scale_exp - 120 # can be negative → subnormal
|
||||
|
||||
# Normal path: exp >= 1, just truncate mantissa top 3 bits
|
||||
# RNE rounding: need guard (bit 19), sticky (OR of bits 18:0), and LSB of result
|
||||
normal_mant = scale_mant >> 20
|
||||
guard_bit = (scale_mant >> 19) & 1
|
||||
sticky_bit = tl.where((scale_mant & 0x7FFFF) != 0, 1, 0) # OR of bits [18:0]
|
||||
result_lsb = normal_mant & 1
|
||||
# RNE: round up if (guard=1 and sticky=1) or (guard=1 and sticky=0 and lsb=1)
|
||||
round_up = guard_bit & (sticky_bit | result_lsb)
|
||||
normal_mant = normal_mant + round_up
|
||||
normal_exp = e4m3_exp_raw
|
||||
|
||||
# Subnormal path: exp_raw <= 0
|
||||
# Insert implicit leading 1 and right-shift by (1 - exp_raw)
|
||||
# E4M3 subnormal: value = (mant/8) * 2^(1-7) = (mant/8) * 2^-6
|
||||
# So we need: (1 + mant_fp32/2^23) * 2^(exp_raw - 7) = (shifted_mant/8) * 2^-6
|
||||
# shifted_mant = (implicit_1 | mant_fp32) >> (1 - exp_raw - 1) then take top 3 bits
|
||||
shift = 1 - e4m3_exp_raw # positive when subnormal
|
||||
mant_with_leading = (0x800000 | scale_mant) # insert implicit 1
|
||||
# Right-shift to get into the 3-bit E4M3 mantissa window
|
||||
# We want bits [shift+19 : shift+23) of mant_with_leading for 3 mantissa bits + 1 round bit
|
||||
subnormal_mant = (mant_with_leading >> (shift.to(tl.int32) + 20)) & 0x7
|
||||
sub_guard_bit = (mant_with_leading >> (shift.to(tl.int32) + 19)) & 1
|
||||
# Sticky: OR of all bits below the guard bit in the shifted result
|
||||
# shift ≤ 8 in practice (amax floor = 1e-4 → scale ≈ 2^-15 → exp_raw ≈ -7), so mask ≤ 2^27
|
||||
sub_sticky_mask = (1 << (shift.to(tl.int32) + 19)) - 1
|
||||
sub_sticky_bit = tl.where((mant_with_leading & sub_sticky_mask) != 0, 1, 0)
|
||||
sub_result_lsb = subnormal_mant & 1
|
||||
sub_round_up = sub_guard_bit & (sub_sticky_bit | sub_result_lsb)
|
||||
subnormal_mant = subnormal_mant + sub_round_up
|
||||
|
||||
is_normal = e4m3_exp_raw >= 1
|
||||
e4m3_mant = tl.where(is_normal, normal_mant, subnormal_mant)
|
||||
e4m3_exp = tl.where(is_normal, normal_exp, 0) # exp=0 for subnormals
|
||||
|
||||
# Handle mantissa overflow after rounding
|
||||
overflow = e4m3_mant >= 8
|
||||
e4m3_mant = tl.where(overflow, 0, e4m3_mant)
|
||||
e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
|
||||
e4m3_exp = tl.maximum(e4m3_exp, 0)
|
||||
e4m3_exp = tl.minimum(e4m3_exp, 15)
|
||||
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant
|
||||
|
||||
# Reconstruct dequantized scale for E2M1 quantization
|
||||
e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126)
|
||||
# Reconstruct dequantized scale by decoding the STORED E4M3 bits.
|
||||
# This guarantees the E2M1 quantization divides by exactly the value
|
||||
# the CUDA kernel will multiply back — same bits, single decode, no
|
||||
# possibility of encode/decode disagreement.
|
||||
stored_exp = (scale_e4m3_bits >> 3) & 0xF
|
||||
stored_mant = scale_e4m3_bits & 0x7
|
||||
e4m3_exp_for_recon = tl.maximum(stored_exp.to(tl.int32) - 7, -126)
|
||||
two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
|
||||
two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
|
||||
normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp
|
||||
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625
|
||||
e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value)
|
||||
normal_value = (1.0 + stored_mant.to(tl.float32) / 8.0) * two_pow_exp
|
||||
subnormal_value = (stored_mant.to(tl.float32) / 8.0) * 0.015625
|
||||
e4m3_value = tl.where(stored_exp == 0, subnormal_value, normal_value)
|
||||
|
||||
# ---- E2M1 FP4 quantization (unpacked, 1 byte/element) ----
|
||||
# E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
"""Minimal test for fp8_nvfp4_mega_moe kernel with synthetic data.
|
||||
|
||||
Run inside the vllm container:
|
||||
python3 /patches/test_nvfp4_mega_moe.py
|
||||
"""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import os
|
||||
import sys
|
||||
|
||||
def test_nvfp4_mega_moe():
|
||||
# Use dimensions that satisfy all alignment requirements:
|
||||
# - hidden and intermediate_hidden must be multiples of 128 and 64
|
||||
# - block_m will be at least 32 (SMEM alignment: 32 * 64 = 2048 >= 1024)
|
||||
num_experts = 2
|
||||
num_tokens = 32 # must be multiple of alignment
|
||||
top_k = 2
|
||||
hidden = 512 # multiple of 128 and 64
|
||||
intermediate_hidden = 1024 # multiple of 128 and 64
|
||||
|
||||
device = "cuda"
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# Single-rank process group for SymmBuffer
|
||||
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
|
||||
os.environ.setdefault("MASTER_PORT", "29501")
|
||||
os.environ.setdefault("RANK", "0")
|
||||
os.environ.setdefault("WORLD_SIZE", "1")
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group("nccl")
|
||||
group = dist.new_group()
|
||||
|
||||
from deep_gemm.mega import (
|
||||
fp8_nvfp4_mega_moe,
|
||||
get_symm_buffer_for_nvfp4_mega_moe,
|
||||
transform_nvfp4_weights_for_mega_moe,
|
||||
)
|
||||
|
||||
# --- Weights: random NVFP4 ---
|
||||
w13_weight = torch.randint(0, 256, (num_experts, 2 * intermediate_hidden, hidden // 2),
|
||||
dtype=torch.uint8, device=device).view(torch.int8)
|
||||
w13_weight_scale = torch.randn(num_experts, 2 * intermediate_hidden, hidden // 16,
|
||||
device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn)
|
||||
w13_weight_scale_2 = torch.ones(num_experts, device=device) # global scale = 1 for simplicity
|
||||
w2_weight = torch.randint(0, 256, (num_experts, hidden, intermediate_hidden // 2),
|
||||
dtype=torch.uint8, device=device).view(torch.int8)
|
||||
w2_weight_scale = torch.randn(num_experts, hidden, intermediate_hidden // 16,
|
||||
device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn)
|
||||
w2_weight_scale_2 = torch.ones(num_experts, device=device)
|
||||
|
||||
print("Transforming weights...")
|
||||
l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe(
|
||||
(w13_weight, w13_weight_scale),
|
||||
(w2_weight, w2_weight_scale),
|
||||
l1_weight_scale_2=w13_weight_scale_2,
|
||||
l2_weight_scale_2=w2_weight_scale_2,
|
||||
)
|
||||
for name, t in [("l1_w", l1_weights[0]), ("l1_w_sf", l1_weights[1]),
|
||||
("l2_w", l2_weights[0]), ("l2_w_sf", l2_weights[1])]:
|
||||
print(f" {name}: dtype={t.dtype} shape={tuple(t.shape)} strides={t.stride()} contig={t.is_contiguous()}")
|
||||
|
||||
# --- Symm buffer ---
|
||||
print("Creating symm buffer...")
|
||||
symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
|
||||
group, num_experts, num_tokens, top_k, hidden, intermediate_hidden)
|
||||
for name, t in [("x", symm_buffer.x), ("x_sf", symm_buffer.x_sf),
|
||||
("l1_acts", symm_buffer.l1_acts), ("l1_acts_sf", symm_buffer.l1_acts_sf),
|
||||
("l2_acts", symm_buffer.l2_acts), ("l2_acts_sf", symm_buffer.l2_acts_sf)]:
|
||||
print(f" symm_{name}: dtype={t.dtype} shape={tuple(t.shape)} strides={t.stride()}")
|
||||
|
||||
# --- Stage inputs (BF16 hidden_states → FP4 packed + UE4M3 scales) ---
|
||||
print("Staging inputs...")
|
||||
hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.5
|
||||
topk_weights = torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1)
|
||||
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=device)
|
||||
|
||||
# Import the staging kernel directly (can't import from vllm model without full init)
|
||||
import triton
|
||||
import triton.language as tl
|
||||
# Just manually pack a few tokens into FP4 for testing
|
||||
# Write actual FP4 data to the activation buffer (random but valid packed E2M1)
|
||||
symm_buffer.x[:num_tokens].copy_(
|
||||
torch.randint(0, 256, (num_tokens, hidden // 2), dtype=torch.uint8, device=device).view(torch.int8))
|
||||
# Write valid UE4M3 scales (random but non-zero)
|
||||
# x_sf shape is (tokens, hidden//64) as int32 — each int32 = 4 packed UE4M3 bytes
|
||||
# Just fill with simple non-zero int32 values (the data doesn't need to be
|
||||
# perfectly valid UE4M3 for a launch test, just non-garbage)
|
||||
symm_buffer.x_sf[:num_tokens].fill_(0x3C3C3C3C) # repeating 0x3C = ~0.5 in E4M3
|
||||
# Write topk data directly
|
||||
for i in range(num_tokens):
|
||||
for j in range(top_k):
|
||||
symm_buffer.topk_idx[i, j] = topk_ids[i, j].item()
|
||||
symm_buffer.topk_weights[i, j] = topk_weights[i, j].item()
|
||||
torch.cuda.synchronize()
|
||||
print("Buffer populated with random FP4 data")
|
||||
|
||||
# --- Run kernel ---
|
||||
y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device)
|
||||
print("Calling fp8_nvfp4_mega_moe...", flush=True)
|
||||
import signal
|
||||
timed_out = False
|
||||
def handler(signum, frame):
|
||||
nonlocal timed_out
|
||||
timed_out = True
|
||||
raise TimeoutError("Kernel timeout")
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(15) # 15 second timeout
|
||||
try:
|
||||
fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer)
|
||||
torch.cuda.synchronize()
|
||||
signal.alarm(0)
|
||||
print(f"SUCCESS! y stats: min={y.min().item():.4f} max={y.max().item():.4f} mean={y.mean().item():.4f} nonzero={torch.count_nonzero(y).item()}")
|
||||
except TimeoutError:
|
||||
print("TIMEOUT: kernel did not complete in 15s (GPU hang?)")
|
||||
except Exception as e:
|
||||
signal.alarm(0)
|
||||
print(f"FAILED: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_nvfp4_mega_moe()
|
||||
Reference in New Issue
Block a user