[ROCm][CI] Fix HuggingFace flash_attention_2 accuracy issue in Isaac vision encoder (#32233)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-01-13 00:33:59 -06:00
committed by GitHub
parent 11b6af5280
commit 5e714f7ff4
2 changed files with 27 additions and 0 deletions

View File

@@ -30,3 +30,22 @@ def pytest_collection_modifyitems(config, items):
UserWarning,
stacklevel=1,
)
def patch_hf_vision_attn_for_rocm(model):
"""Force SDPA for HF vision encoders on ROCm.
HF's flash_attention_2 has accuracy issues on ROCm that bypass
torch.backends.cuda settings. This forces SDPA which then uses
math_sdp via the pytest_collection_modifyitems settings.
"""
if not current_platform.is_rocm():
return
inner = getattr(model, "model", model)
if hasattr(inner, "vision_embedding"):
vit = inner.vision_embedding[0]
for layer in vit.encoder.layers:
if hasattr(layer, "self_attn"):
layer.self_attn.vision_config._attn_implementation = "sdpa"