diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index d0a67cf84..fe48a6006 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -385,8 +385,10 @@ class GroupCoordinator:
             self.cpu_group, 1 << 22, 6
         )
 
+        # TODO(#35915): Remove is_tpu() check once tpu_inference
+        # overrides use_custom_op_collectives() to return True.
         self.use_custom_op_call = (
-            current_platform.is_cuda_alike() or current_platform.is_tpu()
+            current_platform.is_tpu() or current_platform.use_custom_op_collectives()
         )
 
         self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index af627964f..d3d75d883 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -574,9 +574,13 @@ class CudaPlatformBase(Platform):
         return True
 
     @classmethod
-    def num_compute_units(cls, device_id=0):
+    def num_compute_units(cls, device_id: int = 0) -> int:
         return torch.cuda.get_device_properties(device_id).multi_processor_count
 
+    @classmethod
+    def use_custom_op_collectives(cls) -> bool:
+        return True
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 5dae76757..3b56001ed 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -654,6 +654,15 @@ class Platform:
         """
         return False
 
+    @classmethod
+    def use_custom_op_collectives(cls) -> bool:
+        """
+        Whether this platform should use torch.ops.vllm.* custom ops for collectives.
+
+        Returns False by default - platforms must explicitly opt-in.
+        """
+        return False
+
     @classmethod
     def use_sync_weight_loader(cls) -> bool:
         """
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 94675e3c9..56d654961 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -820,5 +820,9 @@ class RocmPlatform(Platform):
         return True
 
     @classmethod
-    def num_compute_units(cls, device_id=0):
+    def num_compute_units(cls, device_id: int = 0) -> int:
         return torch.cuda.get_device_properties(device_id).multi_processor_count
+
+    @classmethod
+    def use_custom_op_collectives(cls) -> bool:
+        return True