Fix device mismatch in test
This commit is contained in:
@@ -245,7 +245,7 @@ def main():
|
||||
# (it's a plain BF16 tensor, not a quantized layer)
|
||||
|
||||
# Build cos_sin_cache
|
||||
cos_sin_cache = build_cos_sin_cache()
|
||||
cos_sin_cache = build_cos_sin_cache().to(DEVICE)
|
||||
|
||||
# Simulate attention output (what FlashMLA would produce)
|
||||
print("\n--- Simulating attention output ---")
|
||||
|
||||
Reference in New Issue
Block a user