Add platform method to enable custom collective ops registration (#34760)

Signed-off-by: Naina Kuruballi Mahesh <nainakm@meta.com>
This commit is contained in:
nkm-meta
2026-03-04 16:50:32 -08:00
committed by GitHub
parent 2ed4722e26
commit 792cbd64ca
4 changed files with 22 additions and 3 deletions

View File

@@ -385,8 +385,10 @@ class GroupCoordinator:
             self.cpu_group, 1 << 22, 6
         )
+        # TODO(#35915): Remove is_tpu() check once tpu_inference
+        # overrides use_custom_op_collectives() to return True.
         self.use_custom_op_call = (
-            current_platform.is_cuda_alike() or current_platform.is_tpu()
+            current_platform.is_tpu() or current_platform.use_custom_op_collectives()
         )
         self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(

View File

@@ -574,9 +574,13 @@ class CudaPlatformBase(Platform):
         return True

     @classmethod
-    def num_compute_units(cls, device_id=0):
+    def num_compute_units(cls, device_id: int = 0) -> int:
         return torch.cuda.get_device_properties(device_id).multi_processor_count

+    @classmethod
+    def use_custom_op_collectives(cls) -> bool:
+        return True
+
     # NVML utils
     # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

View File

@@ -654,6 +654,15 @@ class Platform:
         """
         return False

+    @classmethod
+    def use_custom_op_collectives(cls) -> bool:
+        """
+        Whether this platform should use torch.ops.vllm.* custom ops for collectives.
+
+        Returns False by default - platforms must explicitly opt-in.
+        """
+        return False
+
     @classmethod
     def use_sync_weight_loader(cls) -> bool:
         """

View File

@@ -820,5 +820,9 @@ class RocmPlatform(Platform):
         return True

     @classmethod
-    def num_compute_units(cls, device_id=0):
+    def num_compute_units(cls, device_id: int = 0) -> int:
         return torch.cuda.get_device_properties(device_id).multi_processor_count
+
+    @classmethod
+    def use_custom_op_collectives(cls) -> bool:
+        return True