diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 45ba1bef9..754f2981c 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -74,7 +74,7 @@ class SMControlContextManager: "SM control is currently only supported on CUDA" ) - total_sms = num_compute_units(torch.cuda.current_device().index) + total_sms = num_compute_units(torch.cuda.current_device()) assert comm_sms < total_sms self.total_sms = total_sms