[Bugfix] Fix tensor parallel for qwen2 classification model (#10297)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -21,14 +21,14 @@ def test_classification_models(
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.classify(example_prompts)
|
||||
|
||||
with hf_runner(model,
|
||||
dtype=dtype,
|
||||
auto_cls=AutoModelForSequenceClassification) as hf_model:
|
||||
hf_outputs = hf_model.classify(example_prompts)
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.classify(example_prompts)
|
||||
|
||||
print(hf_outputs, vllm_outputs)
|
||||
|
||||
# check logits difference
|
||||
|
||||
Reference in New Issue
Block a user