Improve the output precision of embedding models (#19092)
This commit is contained in:
@@ -56,14 +56,10 @@ def correctness_test_embed_models(hf_runner,
                      max_model_len=None,
                      **vllm_extra_kwargs) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
-        vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
-        model_dtype = getattr(
-            vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
-            vllm_dtype)
 
     with hf_runner(
             model_info.name,
-            dtype=model_dtype,
+            dtype="float32",
             is_sentence_transformer=True,
     ) as hf_model:
 
Reference in New Issue
Block a user