diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index 225418356..0f587558b 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -57,6 +57,8 @@ FP32_STATE_MODELS = [
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
+ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto"
+
 
 @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
@@ -82,7 +84,9 @@ def test_models(
             example_prompts, max_tokens, num_logprobs
         )
 
-    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+    with vllm_runner(
+        model, max_num_seqs=MAX_NUM_SEQS, attention_backend=ATTN_BACKEND
+    ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs
         )
@@ -157,6 +161,7 @@ def test_chunked_prefill_with_parallel_sampling(
         # forces prefill chunks with decoding
         max_num_batched_tokens=MAX_NUM_SEQS * 3,
         max_num_seqs=MAX_NUM_SEQS,
+        attention_backend=ATTN_BACKEND,
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
 
@@ -301,7 +306,9 @@ def test_full_cuda_graph(
             example_prompts, max_tokens, num_logprobs
         )
 
-    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+    with vllm_runner(
+        model, max_num_seqs=MAX_NUM_SEQS, attention_backend=ATTN_BACKEND
+    ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs
         )
@@ -370,6 +377,7 @@ def _get_vllm_runner_params(
         "max_model_len": max_model_len,
         "tensor_parallel_size": tensor_parallel_size,
         "gpu_memory_utilization": 0.4,
+        "attention_backend": ATTN_BACKEND,
     }
 
 
@@ -844,6 +852,7 @@ def test_apc_common_prefix_same_batch(
         mamba_block_size=16,
         enable_prefix_caching=True,
         seed=42,
+        attention_backend=ATTN_BACKEND,
     )
     prompts = [
         "hello what is one plus one what is one plus one what is one plus one the answer is",  # noqa: E501