[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)
This commit is contained in:
@@ -18,8 +18,10 @@ from ..utils import fork_new_process_for_each_test
|
||||
@pytest.mark.parametrize("model, distributed_executor_backend", [
|
||||
("llava-hf/llava-1.5-7b-hf", "ray"),
|
||||
("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
|
||||
("facebook/chameleon-7b", "ray"),
|
||||
("llava-hf/llava-1.5-7b-hf", "mp"),
|
||||
("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
|
||||
("facebook/chameleon-7b", "mp"),
|
||||
])
|
||||
@fork_new_process_for_each_test
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model: str,
|
||||
@@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
|
||||
from ..models.test_llava import models, run_test
|
||||
elif model.startswith("llava-hf/llava-v1.6"):
|
||||
from ..models.test_llava_next import models, run_test
|
||||
elif model.startswith("facebook/chameleon"):
|
||||
from ..models.test_chameleon import models, run_test
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported model: {model}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user