[ROCm][CI] Pin test_hybrid test to TRITON_ATTN on ROCm (#38381)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Authored by: Micah Williamson
Date: 2026-03-30 15:26:46 -05:00
Committed by: GitHub
Parent commit: 12701e8af2
This commit: d9c7db18da

View File

@@ -57,6 +57,8 @@ FP32_STATE_MODELS = [
# Avoid OOM
MAX_NUM_SEQS = 4
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto"
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@@ -82,7 +84,9 @@ def test_models(
example_prompts, max_tokens, num_logprobs
)
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
with vllm_runner(
model, max_num_seqs=MAX_NUM_SEQS, attention_backend=ATTN_BACKEND
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs
)
@@ -157,6 +161,7 @@ def test_chunked_prefill_with_parallel_sampling(
# forces prefill chunks with decoding
max_num_batched_tokens=MAX_NUM_SEQS * 3,
max_num_seqs=MAX_NUM_SEQS,
attention_backend=ATTN_BACKEND,
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
@@ -301,7 +306,9 @@ def test_full_cuda_graph(
example_prompts, max_tokens, num_logprobs
)
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
with vllm_runner(
model, max_num_seqs=MAX_NUM_SEQS, attention_backend=ATTN_BACKEND
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs
)
@@ -370,6 +377,7 @@ def _get_vllm_runner_params(
"max_model_len": max_model_len,
"tensor_parallel_size": tensor_parallel_size,
"gpu_memory_utilization": 0.4,
"attention_backend": ATTN_BACKEND,
}
@@ -844,6 +852,7 @@ def test_apc_common_prefix_same_batch(
mamba_block_size=16,
enable_prefix_caching=True,
seed=42,
attention_backend=ATTN_BACKEND,
)
prompts = [
"hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501