[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982)

Signed-off-by: Sourashis Roy <sroy@roblox.com>
This commit is contained in:
sroy745
2024-11-12 10:53:57 -08:00
committed by GitHub
parent 7c65527918
commit b41fb9d3b1
5 changed files with 117 additions and 80 deletions

View File

@@ -243,6 +243,8 @@ def test_rope_customization():
assert longchat_model_config.max_model_len == 4096
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
("facebook/opt-125m", False),
("facebook/bart-base", True),