[ROCm][CI] Pin test_hybrid test to TRITON_ATTN on ROCm (#38381)
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
@@ -57,6 +57,8 @@ FP32_STATE_MODELS = [
|
||||
# Avoid OOM
|
||||
MAX_NUM_SEQS = 4
|
||||
|
||||
+ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@@ -82,7 +84,9 @@ def test_models(
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
-with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
+with vllm_runner(
|
||||
+model, max_num_seqs=MAX_NUM_SEQS, attention_backend=ATTN_BACKEND
|
||||
+) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
@@ -157,6 +161,7 @@ def test_chunked_prefill_with_parallel_sampling(
|
||||
# forces prefill chunks with decoding
|
||||
max_num_batched_tokens=MAX_NUM_SEQS * 3,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
+attention_backend=ATTN_BACKEND,
|
||||
) as vllm_model:
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
@@ -301,7 +306,9 @@ def test_full_cuda_graph(
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
-with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
+with vllm_runner(
|
||||
+model, max_num_seqs=MAX_NUM_SEQS, attention_backend=ATTN_BACKEND
|
||||
+) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
@@ -370,6 +377,7 @@ def _get_vllm_runner_params(
|
||||
"max_model_len": max_model_len,
|
||||
"tensor_parallel_size": tensor_parallel_size,
|
||||
"gpu_memory_utilization": 0.4,
|
||||
+"attention_backend": ATTN_BACKEND,
|
||||
}
|
||||
|
||||
|
||||
@@ -844,6 +852,7 @@ def test_apc_common_prefix_same_batch(
|
||||
mamba_block_size=16,
|
||||
enable_prefix_caching=True,
|
||||
seed=42,
|
||||
+attention_backend=ATTN_BACKEND,
|
||||
)
|
||||
prompts = [
|
||||
"hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user