"""Exploratory test: report which CuTeDSL math operations the installed cutlass package exposes."""
|
|
import sys
|
|
import os
|
|
# Put the directory two levels above this file's folder at the front of
# sys.path so that modules living there can be imported by the test.
sys.path[:0] = [
    os.path.join(os.path.dirname(__file__), "..", ".."),
]
|
|
|
|
# Substrings used to spot math-like attribute names. 'rsqrt' needs no entry
# of its own because any name containing it also contains 'sqrt'.
_BASE_MATH_KEYWORDS = ("sqrt", "log", "exp", "abs")
_CUTLASS_MATH_KEYWORDS = _BASE_MATH_KEYWORDS + ("rcp",)
_CUTE_MATH_KEYWORDS = _BASE_MATH_KEYWORDS + ("sin", "cos", "rcp")
_ARCH_MATH_KEYWORDS = _CUTE_MATH_KEYWORDS + ("fma", "div")

# Names probed one by one in the "Key functions" section.
_CUTE_KEY_NAMES = ("exp", "log", "sqrt", "math")
_CUTE_MATH_KEY_NAMES = (
    "fmax", "fmin", "absf", "sqrt", "log", "exp",
    "rsqrt", "rcp", "sin", "cos", "copysign", "clamp",
)
_CUTE_ARCH_KEY_NAMES = ("fmax", "fmin")

# Candidate MLIR binding modules; any that fail to import are skipped.
_MLIR_MODULE_NAMES = ("cutlass._mlir_ops", "cutlass.mlir", "cutlass.cute._mlir")


def _matching_attrs(obj, keywords):
    """Return sorted attribute names of *obj* whose lowercased name contains any of *keywords*."""
    return [attr for attr in sorted(dir(obj)) if any(k in attr.lower() for k in keywords)]


def _print_matching(obj, label, keywords):
    """Print one ``" <label>.<attr>"`` line per keyword-matching attribute of *obj*."""
    for attr in _matching_attrs(obj, keywords):
        print(f" {label}.{attr}")


def test_cute_math_api():
    """Enumerate available CuTeDSL math/arch operations.

    Prints, per namespace (``cute.math``, ``cute.arch``, ``cute``, ``cutlass``
    and a few optional MLIR binding modules), the attributes whose names look
    math-related, then whether a fixed list of functions exists. Nothing is
    asserted; the output is meant to be read.

    Raises:
        unittest.SkipTest: if ``cutlass`` or ``cutlass.cute`` cannot be
            imported, so pytest reports a skip instead of an error.
    """
    import importlib
    import unittest

    try:
        import cutlass
        import cutlass.cute as cute
    except ImportError as exc:
        raise unittest.SkipTest(f"cutlass (CuTeDSL) is not importable: {exc}") from exc

    # Either submodule may be missing, so probe once and guard every use.
    has_math = hasattr(cute, "math")
    has_arch = hasattr(cute, "arch")

    print("=== cute.math attributes ===")
    if has_math:
        for attr in sorted(dir(cute.math)):
            if not attr.startswith("_"):
                print(f" cute.math.{attr}")
    else:
        print(" cute.math does not exist")

    print("\n=== cute.arch math-related attributes ===")
    if has_arch:
        _print_matching(cute.arch, "cute.arch", _ARCH_MATH_KEYWORDS)
    else:
        print(" cute.arch does not exist")

    print("\n=== cute math-related attributes ===")
    _print_matching(cute, "cute", _CUTE_MATH_KEYWORDS)

    print("\n=== cutlass math-related attributes ===")
    _print_matching(cutlass, "cutlass", _CUTLASS_MATH_KEYWORDS)

    print("\n=== Key functions ===")
    for name in _CUTE_KEY_NAMES:
        print(f" cute.{name} exists: {hasattr(cute, name)}")

    if has_math:
        for name in _CUTE_MATH_KEY_NAMES:
            print(f" cute.math.{name} exists: {hasattr(cute.math, name)}")

    # Dereferencing cute.arch without the has_arch guard raises AttributeError
    # when the submodule is absent; report False for its members instead.
    print()
    for name in _CUTE_ARCH_KEY_NAMES:
        print(f" cute.arch.{name} exists: {has_arch and hasattr(cute.arch, name)}")

    print("\n=== MLIR operations ===")
    for mod_name in _MLIR_MODULE_NAMES:
        try:
            mod = importlib.import_module(mod_name)
        except ImportError:
            continue  # optional binding module not present in this build
        math_attrs = _matching_attrs(mod, _BASE_MATH_KEYWORDS)
        if math_attrs:
            print(f" {mod_name}: {math_attrs}")

    print("\nDone.")
|
|
|
|
if __name__ == "__main__":
    # Allow running the API survey directly as a script, outside pytest.
    test_cute_math_api()
|