[ROCm][CI] Prep Tests For Change To ROCM_ATTN As New Default Backend On ROCm (#36025)
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
@@ -732,11 +732,13 @@ def test_mtp_correctness(
|
||||
method, model_name, tp_size = model_setup
|
||||
_skip_if_insufficient_gpus_for_tp(tp_size)
|
||||
|
||||
attn_backend = "TRITON_ATTN" if current_platform.is_rocm() else "auto"
|
||||
ref_llm = LLM(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
tensor_parallel_size=tp_size,
|
||||
trust_remote_code=True,
|
||||
attention_backend=attn_backend,
|
||||
)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
evaluate_llm_for_gsm8k(
|
||||
@@ -756,6 +758,7 @@ def test_mtp_correctness(
|
||||
"max_model_len": 2048,
|
||||
},
|
||||
max_model_len=2048,
|
||||
attention_backend=attn_backend,
|
||||
)
|
||||
evaluate_llm_for_gsm8k(
|
||||
spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
|
||||
|
||||
Reference in New Issue
Block a user