[ROCm][CI] Fix HuggingFace flash_attention_2 accuracy issue in Isaac vision encoder (#32233)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -30,3 +30,22 @@ def pytest_collection_modifyitems(config, items):
|
||||
UserWarning,
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
|
||||
def patch_hf_vision_attn_for_rocm(model):
    """Switch an HF vision encoder's attention implementation to SDPA on ROCm.

    HF's flash_attention_2 path has accuracy issues on ROCm and is not
    affected by the torch.backends.cuda settings. Selecting "sdpa" instead
    routes attention through SDPA, which then uses math_sdp via the
    settings applied in pytest_collection_modifyitems.

    The function is a no-op when the current platform is not ROCm, or when
    the (unwrapped) model exposes no ``vision_embedding`` attribute.

    Args:
        model: The HF model to patch in place. If it has a ``model``
            attribute, that inner object is patched; otherwise ``model``
            itself is used.
    """
    # Nothing to do off ROCm.
    if not current_platform.is_rocm():
        return

    # Some wrappers keep the actual network under ``.model``; fall back to
    # the object itself when that attribute is absent.
    core = getattr(model, "model", model)
    if not hasattr(core, "vision_embedding"):
        return

    # The first element of ``vision_embedding`` is the module whose
    # ``encoder.layers`` are walked below.
    vision_tower = core.vision_embedding[0]
    for block in vision_tower.encoder.layers:
        if not hasattr(block, "self_attn"):
            continue
        # Override the attention backend recorded on the layer's config.
        block.self_attn.vision_config._attn_implementation = "sdpa"
|
||||
|
||||
Reference in New Issue
Block a user