17 lines
839 B
Python
17 lines
839 B
Python
#!/usr/bin/env python3
|
|
"""Test: quantize_activation_nvfp4 on different GPUs."""
|
|
import torch
|
|
from dsv4.ops.quantize import quantize_activation_nvfp4
|
|
|
|
def contains_nan(tensor):
    """Return True if *tensor* holds at least one NaN value.

    Packed FP4 tensors (``torch.float4_e2m1fn_x2``) always report False:
    the E2M1 format has no NaN encoding, and reinterpreting the packed bytes
    as a wider float type (e.g. ``float16``) would turn ordinary FP4 code
    pairs into arbitrary bit patterns, some of which decode as NaN.
    """
    # The dtype only exists in recent torch releases; look it up defensively
    # so the script still runs when the quantizer returns e.g. uint8.
    packed_fp4 = getattr(torch, "float4_e2m1fn_x2", None)
    if packed_fp4 is not None and tensor.dtype == packed_fp4:
        return False
    return bool(torch.isnan(tensor).any().item())


torch.manual_seed(42)

# Second argument forwarded to quantize_activation_nvfp4
# (presumably the global activation scale -- confirm against dsv4.ops.quantize).
gsa = 0.000375
num_visible = torch.cuda.device_count()

# Run the same quantization on GPU 0 and GPU 1 so the outputs can be compared.
for gpu in (0, 1):
    if gpu >= num_visible:
        # Report and move on instead of failing with an invalid device ordinal.
        print(f"GPU {gpu} quantize: skipped ({num_visible} CUDA device(s) visible)")
        continue

    dev = f"cuda:{gpu}"
    x = torch.randn(1, 7168, dtype=torch.bfloat16, device=dev) * 0.5
    x_fp4, x_sf = quantize_activation_nvfp4(x, gsa)
    has_nan = contains_nan(x_fp4)
    print(f"GPU {gpu} quantize: x_fp4 shape={x_fp4.shape} dtype={x_fp4.dtype} x_sf shape={x_sf.shape} has_nan={has_nan}")

    # Inspect the raw storage bytes of the quantized output.
    raw_bytes = x_fp4.view(torch.uint8)
    print(f" x_fp4 uint8 range: [{raw_bytes.min().item()}, {raw_bytes.max().item()}]")

    sf_as_float = x_sf.float()
    print(f" x_sf float range: [{sf_as_float.min().item():.6f}, {sf_as_float.max().item():.6f}]")