diff --git a/tests/unit/test_production_fmha_layer.py b/tests/unit/test_production_fmha_layer.py index c548749a..879182e6 100644 --- a/tests/unit/test_production_fmha_layer.py +++ b/tests/unit/test_production_fmha_layer.py @@ -20,6 +20,7 @@ import torch.nn.functional as F CHECKPOINT_DIR = os.environ.get( "CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4") NUM_GPUS = int(os.environ.get("NUM_GPUS", "8")) +DEVICE = "cuda:0" def cosine(a, b):