Files
nvfp4-megamoe-kernel/tests/unit/test_cute_math_api.py

79 lines
3.3 KiB
Python

"""Test: check what CuTeDSL math operations are available."""
import sys
import os
# Put the directory two levels above this file's directory at the front of
# the import search path. NOTE(review): presumably this makes project-root
# packages importable when the file is run directly -- confirm; nothing in
# the visible code imports from that directory.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
def test_cute_math_api():
    """Enumerate available CuTeDSL math/arch operations.

    Diagnostic only: everything discovered is printed to stdout and nothing
    is asserted, so under pytest this passes whenever it runs to completion.

    Raises:
        ImportError: if ``cutlass`` or ``cutlass.cute`` cannot be imported
            (the top-level imports below are deliberately not guarded).
    """
    import importlib

    import cutlass
    import cutlass.cute as cute

    # Each section searches for a slightly different keyword set; build them
    # incrementally so the shared part is spelled out only once.
    mlir_keywords = ('sqrt', 'log', 'exp', 'abs', 'rsqrt')
    cutlass_keywords = mlir_keywords + ('rcp',)
    cute_keywords = cutlass_keywords + ('sin', 'cos')
    arch_keywords = cute_keywords + ('fma', 'div')

    def mentions_any(name, keywords):
        """Return True if *name* contains any keyword, ignoring case."""
        lowered = name.lower()
        return any(keyword in lowered for keyword in keywords)

    # Both submodules are treated as optional everywhere below.
    has_math = hasattr(cute, 'math')
    has_arch = hasattr(cute, 'arch')

    # Check cute.math module
    print("=== cute.math attributes ===")
    if has_math:
        for attr in sorted(dir(cute.math)):
            if not attr.startswith('_'):
                print(f" cute.math.{attr}")
    else:
        print(" cute.math does not exist")

    # Check cute.arch module for math
    print("\n=== cute.arch math-related attributes ===")
    if has_arch:
        for attr in sorted(dir(cute.arch)):
            if mentions_any(attr, arch_keywords):
                print(f" cute.arch.{attr}")

    # Check cute directly for math
    print("\n=== cute math-related attributes ===")
    for attr in sorted(dir(cute)):
        if mentions_any(attr, cute_keywords):
            print(f" cute.{attr}")

    # Check cutlass module for math
    print("\n=== cutlass math-related attributes ===")
    for attr in sorted(dir(cutlass)):
        if mentions_any(attr, cutlass_keywords):
            print(f" cutlass.{attr}")

    # Report the presence of specific, commonly needed entry points.
    print("\n=== Key functions ===")
    for name in ('exp', 'log', 'sqrt', 'math'):
        print(f" cute.{name} exists: {hasattr(cute, name)}")
    if has_math:
        for name in ('fmax', 'fmin', 'absf', 'sqrt', 'log', 'exp',
                     'rsqrt', 'rcp', 'sin', 'cos', 'copysign', 'clamp'):
            print(f" cute.math.{name} exists: {hasattr(cute.math, name)}")

    # Check arch operations. Guarded by has_arch so a build without
    # cute.arch reports False instead of raising AttributeError.
    print(f"\n cute.arch.fmax exists: {has_arch and hasattr(cute.arch, 'fmax')}")
    print(f" cute.arch.fmin exists: {has_arch and hasattr(cute.arch, 'fmin')}")

    # Try to find math operations in cutlass._mlir_ops or similar
    print("\n=== MLIR operations ===")
    for mod_name in ['cutlass._mlir_ops', 'cutlass.mlir', 'cutlass.cute._mlir']:
        try:
            mod = importlib.import_module(mod_name)
        except ImportError:
            # Best-effort probe: these are candidate module names, so a
            # missing one is expected and intentionally skipped silently.
            continue
        math_attrs = [a for a in dir(mod) if mentions_any(a, mlir_keywords)]
        if math_attrs:
            print(f" {mod_name}: {math_attrs}")

    print("\nDone.")
# Script entry point: run the enumeration directly, without pytest, so the
# printed report goes straight to the terminal.
if __name__ == "__main__":
    test_cute_math_api()