Fix device mismatch in test

This commit is contained in:
2026-05-19 06:36:22 +00:00
parent 6b4b9774d1
commit b4fee70151

View File

@@ -245,7 +245,7 @@ def main():
# (it's a plain BF16 tensor, not a quantized layer)
# Build cos_sin_cache
-cos_sin_cache = build_cos_sin_cache()
+cos_sin_cache = build_cos_sin_cache().to(DEVICE)
# Simulate attention output (what FlashMLA would produce)
print("\n--- Simulating attention output ---")