[ROCm][CI] Extending attention backend coverage for Eagle spec decode tests (#35265)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@@ -30,7 +30,7 @@ steps:
     - pytest -v -s v1/engine --ignore v1/engine/test_preprocess_error_handling.py
   mirror:
     amd:
-      device: mi325_8
+      device: mi325_1
   depends_on:
   - image-build-amd
   commands:
@@ -1327,6 +1327,57 @@ def multi_gpu_test(*, num_gpus: int):
     return wrapper


+def gpu_tier_mark(*, min_gpus: int = 1, max_gpus: int | None = None):
+    """
+    Mark a test to only run when the GPU count falls within [min_gpus, max_gpus].
+
+    Examples:
+        @gpu_tier_mark(min_gpus=2)              # only on multi-GPU
+        @gpu_tier_mark(max_gpus=1)              # only on single-GPU
+        @gpu_tier_mark(min_gpus=2, max_gpus=4)  # 2-4 GPUs only
+    """
+    gpu_count = cuda_device_count_stateless()
+    marks = []
+
+    if min_gpus > 1:
+        marks.append(pytest.mark.distributed(num_gpus=min_gpus))
+
+    reasons = []
+    if gpu_count < min_gpus:
+        reasons.append(f"Need at least {min_gpus} GPUs (have {gpu_count})")
+    if max_gpus is not None and gpu_count > max_gpus:
+        reasons.append(f"Need at most {max_gpus} GPUs (have {gpu_count})")
+
+    if reasons:
+        marks.append(pytest.mark.skipif(True, reason="; ".join(reasons)))
+
+    return marks
+
+
+def single_gpu_only(f=None):
+    """Skip this test when running in a multi-GPU environment."""
+    marks = gpu_tier_mark(max_gpus=1)
+
+    def wrapper(func):
+        for mark in reversed(marks):
+            func = mark(func)
+        return func
+
+    return wrapper(f) if f is not None else wrapper
+
+
+def multi_gpu_only(*, num_gpus: int = 2):
+    """Skip this test when running on fewer than num_gpus GPUs."""
+    marks = gpu_tier_mark(min_gpus=num_gpus)
+
+    def wrapper(f):
+        for mark in reversed(marks):
+            f = mark(f)
+        return f
+
+    return wrapper
+
+
 async def completions_with_server_args(
     prompts: list[str],
     model_name: str,
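For illustration, a minimal usage sketch (not part of the diff; test names are hypothetical) of how the new helpers compose with pytest. gpu_tier_mark returns a list of marks, while single_gpu_only and multi_gpu_only apply them as decorators:

    import pytest
    from tests.utils import gpu_tier_mark, multi_gpu_only, single_gpu_only

    @single_gpu_only
    def test_runs_only_on_one_gpu():
        ...  # skipped when more than one GPU is visible

    @multi_gpu_only(num_gpus=2)
    def test_needs_two_gpus():
        ...  # marked distributed, skipped below two GPUs

    # gpu_tier_mark can also gate a single parametrized case:
    @pytest.mark.parametrize(
        "tp_size",
        [pytest.param(4, marks=gpu_tier_mark(min_gpus=4))],
    )
    def test_tensor_parallel(tp_size):
        ...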
@@ -6,6 +6,7 @@ from typing import Any
 import pytest
 import torch._dynamo.config as dynamo_config

+from tests.utils import large_gpu_mark, single_gpu_only
 from vllm import SamplingParams
 from vllm.logprobs import Logprob
 from vllm.platforms import current_platform
@@ -36,6 +37,7 @@ default_params = dict(
 )


+@single_gpu_only
 def test_without_spec_decoding(
     sample_json_schema,
     monkeypatch: pytest.MonkeyPatch,
@@ -95,6 +97,8 @@ def test_without_spec_decoding(
     run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)


+@single_gpu_only
+@large_gpu_mark(min_gb=16)
 def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
     """Test consistency and acceptance rates with some different combos of
     preemption, executor, async scheduling, prefill chunking,
@@ -9,7 +9,13 @@ import pytest
 import torch

 from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
-from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
+from tests.utils import (
+    get_attn_backend_list_based_on_platform,
+    large_gpu_mark,
+    multi_gpu_marks,
+    multi_gpu_only,
+    single_gpu_only,
+)
 from vllm import LLM, SamplingParams
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
@@ -160,6 +166,8 @@ def reset_torch_dynamo():
         },
     ],
 )
+@single_gpu_only
+@large_gpu_mark(min_gb=20)
 def test_ngram_and_suffix_correctness(
     speculative_config: dict,
     model_name: str,
@@ -175,6 +183,8 @@ def test_ngram_and_suffix_correctness(
     cleanup_dist_env_and_memory()


+@single_gpu_only
+@large_gpu_mark(min_gb=20)
 def test_suffix_decoding_acceptance(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
@@ -242,6 +252,8 @@ def test_suffix_decoding_acceptance(
     ],
     ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
 )
+@single_gpu_only
+@large_gpu_mark(min_gb=24)
 def test_speculators_model_integration(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
@@ -319,137 +331,7 @@ def test_speculators_model_integration(
     )


-@pytest.mark.parametrize(
-    [
-        "model_setup",
-        "mm_enabled",
-        "enable_chunked_prefill",
-        "model_impl",
-        "expected_accuracy_threshold",
-    ],
-    [
-        (
-            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
-            False,
-            False,
-            "auto",
-            0.8,  # ref: 90%
-        ),
-        (
-            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
-            False,
-            False,
-            "transformers",
-            0.8,  # ref: 90%
-        ),
-        pytest.param(
-            (
-                "eagle3",
-                "Qwen/Qwen3-VL-8B-Instruct",
-                "taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
-                1,
-            ),
-            False,
-            False,
-            "auto",
-            0.8,  # ref: 90%
-            marks=pytest.mark.skip(
-                reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
-            ),
-        ),
-        pytest.param(
-            (
-                "eagle3",
-                "Qwen/Qwen2.5-VL-7B-Instruct",
-                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
-                1,
-            ),
-            False,
-            False,
-            "auto",
-            0.7,  # TODO, update this with a reference value when re-enabling this case
-            marks=pytest.mark.skip(
-                reason="Skipping due to its head_dim not being a a multiple of 32"
-            ),
-        ),
-        pytest.param(
-            (
-                "eagle",
-                "meta-llama/Llama-3.1-8B-Instruct",
-                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
-                1,
-            ),
-            False,
-            True,
-            "auto",
-            0.7,  # ref: 75%-80%
-            marks=large_gpu_mark(min_gb=40),
-        ),  # works on 4x H100
-        (
-            (
-                "eagle3",
-                "meta-llama/Llama-3.1-8B-Instruct",
-                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
-                1,
-            ),
-            False,
-            False,
-            "auto",
-            0.7,  # ref: 75%-80%
-        ),
-        pytest.param(
-            (
-                "eagle",
-                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
-                4,
-            ),
-            False,
-            False,
-            "auto",
-            0.8,  # ref: 90%
-            # marks=large_gpu_mark(min_gb=80),
-        ),  # works on 4x H100
-        pytest.param(
-            (
-                "eagle",
-                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
-                4,
-            ),
-            True,
-            True,
-            "auto",
-            0.8,  # ref: 90%
-            marks=large_gpu_mark(min_gb=80),
-        ),  # works on 4x H100
-        (
-            (
-                "eagle",
-                "eagle618/deepseek-v3-random",
-                "eagle618/eagle-deepseek-v3-random",
-                1,
-            ),
-            False,
-            False,
-            "auto",
-            0.0,  # dummy model, skip gsm8k check
-        ),
-    ],
-    ids=[
-        "qwen3_eagle3",
-        "qwen3_eagle3-transformers",
-        "qwen3_vl_eagle3",
-        "qwen2_5_vl_eagle3",
-        "llama3_eagle",
-        "llama3_eagle3",
-        "llama4_eagle",
-        "llama4_eagle_mm",
-        "deepseek_eagle",
-    ],
-)
-@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
-def test_eagle_correctness(
+def _run_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
@@ -460,14 +342,10 @@ def test_eagle_correctness(
     attn_backend: str,
 ):
     """
-    Compare the outputs of a original LLM and a speculative LLM
-    which should be the same when using eagle speculative decoding. Due to some variance
-    in the engine, it is possible for some outputs to differ, so we expect that at least
-    6/10 output tokens match exactly, and that the GSM8k accuracy is above
-    a precomputed reference threshold for each model.
+    Compare the outputs of an original LLM and a speculative LLM
+    which should be the same when using eagle speculative decoding.
     """
     if attn_backend == "TREE_ATTN":
         # TODO: Fix this flaky test
         pytest.skip(
             "TREE_ATTN is flaky in the test disable for now until it can be "
             "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
@@ -484,17 +362,17 @@ def test_eagle_correctness(
             f"transformers>={required}, but got {installed}"
         )

     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
+    # Determine attention config
     # Scout requires default backend selection because vision encoder has
     # head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
     # to Flex Attn
     if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
         if current_platform.is_rocm():
-            # TODO: Enable Flex Attn for spec_decode on ROCm
-            pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
-        attention_config = None  # Let it fall back to default
+            print(
+                "FLASH_ATTN for spec_decode not supported on "
+                "ROCm currently. Changing to FLEX_ATTENTION backend."
+            )
+            attention_config = {"backend": "FLEX_ATTENTION"}
+        else:
+            attention_config = None
     else:
         attention_config = {"backend": attn_backend}
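The rewritten ROCm branch above substitutes FLEX_ATTENTION instead of skipping. A condensed sketch of the resulting selection logic (illustrative only; the helper name pick_attention_config is hypothetical and simply mirrors the diff):

    from vllm.platforms import current_platform

    def pick_attention_config(model_name: str, attn_backend: str) -> dict | None:
        # Scout's vision encoder has head_dim 88, which FLASH_ATTN cannot
        # handle, so the test defers to the platform default on CUDA and
        # forces FLEX_ATTENTION on ROCm.
        if "Llama-4-Scout" in model_name and attn_backend == "FLASH_ATTN":
            if current_platform.is_rocm():
                return {"backend": "FLEX_ATTENTION"}
            return None  # fall back to the default backend
        return {"backend": attn_backend}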
@@ -509,7 +387,9 @@ def test_eagle_correctness(

         if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
             if "deepseek" in model_setup[1].lower():
-                pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
+                m.setenv("VLLM_ROCM_USE_AITER", "1")
+                m.delenv("VLLM_MLA_DISABLE", raising=False)
+                attention_config = {"backend": "TRITON_MLA"}
             else:
                 m.setenv("VLLM_ROCM_USE_AITER", "1")
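The m.setenv and m.delenv calls in this hunk rely on pytest's MonkeyPatch, which restores the environment when the context exits, so one test's ROCm/AITER settings cannot leak into the next. A minimal sketch of the pattern (illustrative; the test body is hypothetical):

    import pytest

    def test_env_scoped(monkeypatch: pytest.MonkeyPatch):
        with monkeypatch.context() as m:
            m.setenv("VLLM_ROCM_USE_AITER", "1")
            m.delenv("VLLM_MLA_DISABLE", raising=False)
            ...  # run the engine; both changes revert on exit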
@@ -563,14 +443,235 @@ def test_eagle_correctness(
             print(f"ref_output: {ref_output.outputs[0].text}")
             print(f"spec_output: {spec_output.outputs[0].text}")

     # Heuristic: expect at least 60% of the prompts to match exactly
     # Upon failure, inspect the outputs to check for inaccuracy.
     assert matches > int(0.6 * len(ref_outputs))
     del spec_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()


+@single_gpu_only
+@pytest.mark.parametrize(
+    [
+        "model_setup",
+        "mm_enabled",
+        "enable_chunked_prefill",
+        "model_impl",
+        "expected_accuracy_threshold",
+    ],
+    [
+        (
+            (
+                "eagle",
+                "eagle618/deepseek-v3-random",
+                "eagle618/eagle-deepseek-v3-random",
+                1,
+            ),
+            False,
+            False,
+            "auto",
+            0.0,
+        ),
+    ],
+    ids=["deepseek_eagle"],
+)
+@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
+def test_eagle_correctness_light(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, str, int],
+    mm_enabled: bool,
+    expected_accuracy_threshold: float,
+    enable_chunked_prefill: bool,
+    model_impl: str,
+    attn_backend: str,
+):
+    _run_eagle_correctness(
+        monkeypatch,
+        sampling_config,
+        model_setup,
+        mm_enabled,
+        expected_accuracy_threshold,
+        enable_chunked_prefill,
+        model_impl,
+        attn_backend,
+    )
+
+
+@single_gpu_only
+@large_gpu_mark(min_gb=24)
+@pytest.mark.parametrize(
+    [
+        "model_setup",
+        "mm_enabled",
+        "enable_chunked_prefill",
+        "model_impl",
+        "expected_accuracy_threshold",
+    ],
+    [
+        (
+            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
+            False,
+            False,
+            "auto",
+            0.8,
+        ),
+        (
+            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
+            False,
+            False,
+            "transformers",
+            0.8,
+        ),
+        pytest.param(
+            (
+                "eagle3",
+                "Qwen/Qwen3-VL-8B-Instruct",
+                "taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
+                1,
+            ),
+            False,
+            False,
+            "auto",
+            0.8,
+            marks=pytest.mark.skip(
+                reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
+            ),
+        ),
+        pytest.param(
+            (
+                "eagle3",
+                "Qwen/Qwen2.5-VL-7B-Instruct",
+                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
+                1,
+            ),
+            False,
+            False,
+            "auto",
+            0.7,
+            marks=pytest.mark.skip(
+                reason="Skipping due to its head_dim not being a multiple of 32"
+            ),
+        ),
+        (
+            (
+                "eagle3",
+                "meta-llama/Llama-3.1-8B-Instruct",
+                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
+                1,
+            ),
+            False,
+            False,
+            "auto",
+            0.7,
+        ),
+    ],
+    ids=[
+        "qwen3_eagle3",
+        "qwen3_eagle3-transformers",
+        "qwen3_vl_eagle3",
+        "qwen2_5_vl_eagle3",
+        "llama3_eagle3",
+    ],
+)
+@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
+def test_eagle_correctness_medium(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, str, int],
+    mm_enabled: bool,
+    expected_accuracy_threshold: float,
+    enable_chunked_prefill: bool,
+    model_impl: str,
+    attn_backend: str,
+):
+    _run_eagle_correctness(
+        monkeypatch,
+        sampling_config,
+        model_setup,
+        mm_enabled,
+        expected_accuracy_threshold,
+        enable_chunked_prefill,
+        model_impl,
+        attn_backend,
+    )
+
+
+@pytest.mark.parametrize(
+    [
+        "model_setup",
+        "mm_enabled",
+        "enable_chunked_prefill",
+        "model_impl",
+        "expected_accuracy_threshold",
+    ],
+    [
+        pytest.param(
+            (
+                "eagle",
+                "meta-llama/Llama-3.1-8B-Instruct",
+                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+                1,
+            ),
+            False,
+            True,
+            "auto",
+            0.7,
+            marks=large_gpu_mark(min_gb=40),
+            id="llama3_eagle",
+        ),
+        pytest.param(
+            (
+                "eagle",
+                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
+                4,
+            ),
+            False,
+            False,
+            "auto",
+            0.8,
+            marks=multi_gpu_marks(num_gpus=4),
+            id="llama4_eagle",
+        ),
+        pytest.param(
+            (
+                "eagle",
+                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
+                4,
+            ),
+            True,
+            True,
+            "auto",
+            0.8,
+            marks=[*multi_gpu_marks(num_gpus=4), large_gpu_mark(min_gb=80)],
+            id="llama4_eagle_mm",
+        ),
+    ],
+)
+@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
+def test_eagle_correctness_heavy(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, str, int],
+    mm_enabled: bool,
+    expected_accuracy_threshold: float,
+    enable_chunked_prefill: bool,
+    model_impl: str,
+    attn_backend: str,
+):
+    _run_eagle_correctness(
+        monkeypatch,
+        sampling_config,
+        model_setup,
+        mm_enabled,
+        expected_accuracy_threshold,
+        enable_chunked_prefill,
+        model_impl,
+        attn_backend,
+    )
+
+
 @pytest.mark.parametrize(
     ["model_setup", "mm_enabled", "expected_accuracy_threshold"],
     [
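Each tier above encodes its hardware needs as marks, and the marks compose: the llama4_eagle_mm case stacks multi_gpu_marks with large_gpu_mark. A small sketch of the same composition on a hypothetical case:

    @pytest.mark.parametrize(
        "setup",
        [
            pytest.param(
                "llama4_scout_mm",  # hypothetical case value
                marks=[*multi_gpu_marks(num_gpus=4), large_gpu_mark(min_gb=80)],
            ),
        ],
    )
    def test_heavy_case(setup):
        ...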
@@ -579,6 +680,8 @@ def test_eagle_correctness(
     ],
     ids=["mimo", "deepseek"],
 )
+@single_gpu_only
+@large_gpu_mark(min_gb=20)
 def test_mtp_correctness(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
@@ -694,11 +797,13 @@ cases = [

 @pytest.mark.parametrize("args", cases)
 @pytest.mark.parametrize("enforce_eager", [True, False])
+@single_gpu_only
 def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
     args.enforce_eager = enforce_eager
     assert_draft_model_correctness(args)


+@single_gpu_only
 def test_draft_model_realistic_example():
     args = ArgsTest(
         target_model="Qwen/Qwen3-1.7B",
@@ -713,6 +818,7 @@ def test_draft_model_realistic_example():
     assert_draft_model_correctness(args)


+@single_gpu_only
 def test_draft_model_parallel_drafting():
     args = ArgsTest(
         target_model="Qwen/Qwen3-1.7B",
@@ -738,6 +844,7 @@ def test_draft_model_parallel_drafting():
     ids=["target_quantized", "draft_quantized"],
 )
 @pytest.mark.parametrize("enforce_eager", [True, False])
+@single_gpu_only
 def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
     tgt_model, draft_model = models
     sd_case = ArgsTest(
@@ -749,6 +856,7 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
     assert_draft_model_correctness(sd_case)


+@multi_gpu_only(num_gpus=2)
 def test_draft_model_tensor_parallelism():
     """Ensure spec decode works when running with TP > 1."""
     _skip_if_insufficient_gpus_for_tp(2)
@@ -764,6 +872,7 @@ def test_draft_model_tensor_parallelism():
     assert_draft_model_correctness(sd_case)


+@multi_gpu_only(num_gpus=2)
 def test_draft_model_engine_args_tensor_parallelism():
     """Ensure the vllm_config for the draft model is created correctly,
     and independently of the target model (quantization, TP, etc.)"""