diff --git a/tests/test_o_projection_b200.py b/tests/test_o_projection_b200.py index edf07e83..1640db0d 100644 --- a/tests/test_o_projection_b200.py +++ b/tests/test_o_projection_b200.py @@ -245,7 +245,7 @@ def main(): # (it's a plain BF16 tensor, not a quantized layer) # Build cos_sin_cache - cos_sin_cache = build_cos_sin_cache() + cos_sin_cache = build_cos_sin_cache().to(DEVICE) # Simulate attention output (what FlashMLA would produce) print("\n--- Simulating attention output ---")