diff --git a/tests/unit/test_nvfp4_primitives.py b/tests/unit/test_nvfp4_primitives.py index f4fa4aea..f0c3d2da 100644 --- a/tests/unit/test_nvfp4_primitives.py +++ b/tests/unit/test_nvfp4_primitives.py @@ -105,9 +105,10 @@ def test_nvfp4_primitives(): else: print(f" ❌ Scale factors are {x_sf.dtype} — unexpected!") - # Check what the gemm_runner actually passes to the kernel - from dsv4.ops.gemm_runner import Nvfp4GemmRunner - print(f" Nvfp4GemmRunner exists: {Nvfp4GemmRunner is not None}") + # Check the gemm_runner class name + import dsv4.ops.gemm_runner as gr + runner_classes = [name for name in dir(gr) if 'unner' in name or 'Gemm' in name] + print(f" gemm_runner classes: {runner_classes}") print() print("=" * 60)