[ROCm][CI] fix get_valid_backends (#32787)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
@@ -83,8 +83,14 @@ EXCLUDED_BACKENDS = {AttentionBackendEnum.FLEX_ATTENTION}
|
||||
|
||||
|
||||
def get_available_attention_backends() -> list[str]:
|
||||
if not hasattr(current_platform, "get_valid_backends"):
|
||||
return ["FLASH_ATTN"]
|
||||
# Check if get_valid_backends is actually defined in the platform class
|
||||
# (not just returning None from __getattr__)
|
||||
get_valid_backends = getattr(current_platform.__class__, "get_valid_backends", None)
|
||||
if get_valid_backends is None:
|
||||
if current_platform.is_rocm():
|
||||
return ["TRITON_ATTN"]
|
||||
else:
|
||||
return ["FLASH_ATTN"]
|
||||
|
||||
device_capability = current_platform.get_device_capability()
|
||||
if device_capability is None:
|
||||
|
||||
Reference in New Issue
Block a user