[Model][2/N] Automatic conversion of CrossEncoding model (#19978)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -234,6 +234,35 @@ def run_mteb_rerank(cross_encoder, tasks, languages):
|
||||
return main_score
|
||||
|
||||
|
||||
def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None):
|
||||
with hf_runner(model_name, is_cross_encoder=True,
|
||||
dtype="float32") as hf_model:
|
||||
|
||||
original_predict = hf_model.predict
|
||||
|
||||
def _predict(
|
||||
sentences: list[tuple[str, str,
|
||||
Optional[str]]], # query, corpus, prompt
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
# vllm and st both remove the prompt, fair comparison.
|
||||
prompts = [(s[0], s[1]) for s in sentences]
|
||||
return original_predict(prompts, *args, **kwargs, batch_size=8)
|
||||
|
||||
hf_model.predict = _predict
|
||||
hf_model.original_predict = original_predict
|
||||
|
||||
if hf_model_callback is not None:
|
||||
hf_model_callback(hf_model)
|
||||
|
||||
st_main_score = run_mteb_rerank(hf_model,
|
||||
tasks=MTEB_RERANK_TASKS,
|
||||
languages=MTEB_RERANK_LANGS)
|
||||
st_dtype = next(hf_model.model.model.parameters()).dtype
|
||||
return st_main_score, st_dtype
|
||||
|
||||
|
||||
def mteb_test_rerank_models(hf_runner,
|
||||
vllm_runner,
|
||||
model_info: RerankModelInfo,
|
||||
@@ -264,31 +293,8 @@ def mteb_test_rerank_models(hf_runner,
|
||||
languages=MTEB_RERANK_LANGS)
|
||||
vllm_dtype = model_config.dtype
|
||||
|
||||
with hf_runner(model_info.name, is_cross_encoder=True,
|
||||
dtype="float32") as hf_model:
|
||||
|
||||
original_predict = hf_model.predict
|
||||
|
||||
def _predict(
|
||||
sentences: list[tuple[str, str,
|
||||
Optional[str]]], # query, corpus, prompt
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
# vllm and st both remove the prompt, fair comparison.
|
||||
prompts = [(s[0], s[1]) for s in sentences]
|
||||
return original_predict(prompts, *args, **kwargs, batch_size=8)
|
||||
|
||||
hf_model.predict = _predict
|
||||
hf_model.original_predict = original_predict
|
||||
|
||||
if hf_model_callback is not None:
|
||||
hf_model_callback(hf_model)
|
||||
|
||||
st_main_score = run_mteb_rerank(hf_model,
|
||||
tasks=MTEB_RERANK_TASKS,
|
||||
languages=MTEB_RERANK_LANGS)
|
||||
st_dtype = next(hf_model.model.model.parameters()).dtype
|
||||
st_main_score, st_dtype = mteb_test_rerank_models_hf(
|
||||
hf_runner, model_info.name, hf_model_callback)
|
||||
|
||||
print("VLLM:", vllm_dtype, vllm_main_score)
|
||||
print("SentenceTransformers:", st_dtype, st_main_score)
|
||||
|
||||
Reference in New Issue
Block a user