Fix ExaoneMoeMTP test that never ran in Transformers v4 (#36792)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-03-11 17:10:23 +00:00
committed by GitHub
parent 196802dfa6
commit 5efa206a8c
4 changed files with 17 additions and 0 deletions

View File

@@ -247,6 +247,7 @@ def _compare_tp(
hf_config = get_config(model_id, trust_remote_code)
require_embed_inputs = model_info.require_embed_inputs
max_num_seqs = model_info.max_num_seqs
enable_prefix_caching = model_info.enable_prefix_caching
dtype = "float16"
if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
@@ -300,6 +301,8 @@ def _compare_tp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if not enable_prefix_caching:
common_args.append("--no-enable-prefix-caching")
if require_embed_inputs:
common_args.extend(
[

View File

@@ -74,6 +74,8 @@ def run_test(
if model_info.require_embed_inputs:
for k in ("skip_tokenizer_init", "enable_prompt_embeds", "enable_mm_embeds"):
vllm_runner_kwargs_[k] = model_info.require_embed_inputs
if not model_info.enable_prefix_caching:
vllm_runner_kwargs_["enable_prefix_caching"] = False
if vllm_runner_kwargs:
vllm_runner_kwargs_.update(vllm_runner_kwargs)

View File

@@ -72,6 +72,12 @@ class _HfExamplesInfo:
If False, we will use CUDA graph and eager execution in hybrid.
"""
enable_prefix_caching: bool = True
"""
Whether to enable prefix caching for the model. If True, we will test the model with
prefix caching enabled. If False, we will test the model without prefix caching.
"""
is_available_online: bool = True
"""
Set this to `False` if the name of this architecture no longer exists on
@@ -1206,6 +1212,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"LGAI-EXAONE/K-EXAONE-236B-A23B",
speculative_model="LGAI-EXAONE/K-EXAONE-236B-A23B",
min_transformers_version="5.1.0",
enable_prefix_caching=False,
),
"ExtractHiddenStatesModel": _HfExamplesInfo(
"Qwen/Qwen3-8B",

View File

@@ -136,6 +136,10 @@ def can_initialize(
if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
kwargs = {}
if not model_info.enable_prefix_caching:
kwargs["enable_prefix_caching"] = False
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
@@ -165,6 +169,7 @@ def can_initialize(
hf_overrides=hf_overrides_fn,
max_num_seqs=model_info.max_num_seqs,
attention_config=attention_config,
**kwargs,
)