refactor: abstract deepgemm support into platform (#37519)
Co-authored-by: sherryC41 <sherry.c.c41@gmail.com>
This commit is contained in:
@@ -511,6 +511,11 @@ class CudaPlatformBase(Platform):
|
||||
def support_static_graph_mode(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def support_deep_gemm(cls) -> bool:
|
||||
"""Currently, only Hopper and Blackwell GPUs are supported."""
|
||||
return cls.is_device_capability(90) or cls.is_device_capability_family(100)
|
||||
|
||||
@classmethod
|
||||
def num_compute_units(cls, device_id: int = 0) -> int:
|
||||
return torch.cuda.get_device_properties(device_id).multi_processor_count
|
||||
|
||||
@@ -712,6 +712,13 @@ class Platform:
|
||||
"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def support_deep_gemm(cls) -> bool:
|
||||
"""
|
||||
Returns whether DeepGEMM is supported by the current platform.
|
||||
"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def use_custom_op_collectives(cls) -> bool:
|
||||
"""
|
||||
|
||||
@@ -70,10 +70,7 @@ def is_deep_gemm_supported() -> bool:
|
||||
"""Return `True` if DeepGEMM is supported on the current platform.
|
||||
Currently, only Hopper and Blackwell GPUs are supported.
|
||||
"""
|
||||
is_supported_arch = current_platform.is_cuda() and (
|
||||
current_platform.is_device_capability(90)
|
||||
or current_platform.is_device_capability_family(100)
|
||||
)
|
||||
is_supported_arch = current_platform.support_deep_gemm()
|
||||
return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user