[Model] support minicpm3 (#8297)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: ywfang
Date: 2024-09-14 22:50:26 +08:00
Committer: GitHub
Parent: 1ef0d2efd0
Commit: 8a0cf1ddc3
7 changed files with 282 additions and 38 deletions

tests/models/test_big_models.py

@@ -5,7 +5,8 @@ This tests bigger models and use half precision.
 Run `pytest tests/models/test_big_models.py`.
 """
 import pytest
-import torch
+
+from vllm.platforms import current_platform
 
 from ...utils import check_outputs_equal
@@ -19,10 +20,12 @@ MODELS = [
     # "Qwen/Qwen1.5-0.5B"  # Broken,
 ]
 
+if not current_platform.is_cpu():
+    # MiniCPM requires fused_moe which is not supported by CPU
+    MODELS.append("openbmb/MiniCPM3-4B")
+
 #TODO: remove this after CPU float16 support ready
-target_dtype = "float"
-if torch.cuda.is_available():
-    target_dtype = "half"
+target_dtype = "float" if current_platform.is_cpu() else "half"
 
 
 @pytest.mark.parametrize("model", MODELS)
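
Note: the `current_platform` probe introduced above is vLLM's platform-detection object; the same check can also gate an individual test rather than the module-level MODELS list. A minimal sketch (not part of this commit; the test name is hypothetical):

import pytest

from vllm.platforms import current_platform

# Sketch only: skip on CPU builds using the same platform probe the
# diff above introduces, with the reason taken from its comment.
@pytest.mark.skipif(current_platform.is_cpu(),
                    reason="MiniCPM requires fused_moe, unsupported on CPU")
def test_gpu_only_model():
    ...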
@@ -39,7 +42,7 @@ def test_models(
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
     check_outputs_equal(
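
For context on the new flag: `enforce_eager=True` makes vLLM skip CUDA graph capture and run the model in eager PyTorch mode, trading some throughput for lower memory use during startup. A hedged sketch of the same flag on vLLM's public entry point (not part of this commit; `trust_remote_code=True` is an assumption for MiniCPM3's custom modeling code):

from vllm import LLM, SamplingParams

llm = LLM(
    model="openbmb/MiniCPM3-4B",
    trust_remote_code=True,  # assumption: MiniCPM3 ships custom modeling code
    enforce_eager=True,      # run eagerly, no CUDA graph capture
)
out = llm.generate(["Hello, my name is"],
                   SamplingParams(temperature=0.0, max_tokens=32))
print(out[0].outputs[0].text)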
@@ -57,7 +60,7 @@ def test_model_print(
     model: str,
     dtype: str,
 ) -> None:
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
         print(vllm_model.model.llm_engine.model_executor.driver_worker.
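
The `print` exercised by this test relies on `extra_repr`, the standard `torch.nn.Module` hook for surfacing module configuration in `repr()`. An illustrative, non-vLLM example (the layer and its `scale` field are made up for demonstration):

import torch.nn as nn

class ScaledLinear(nn.Linear):
    """Toy layer showing how extra_repr surfaces config in print(model)."""

    def __init__(self, in_features: int, out_features: int, scale: float = 1.0):
        super().__init__(in_features, out_features)
        self.scale = scale

    def extra_repr(self) -> str:
        # Append our extra field to nn.Linear's own extra_repr output.
        return f"{super().extra_repr()}, scale={self.scale}"

print(ScaledLinear(4, 8, scale=0.5))
# ScaledLinear(in_features=4, out_features=8, bias=True, scale=0.5)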