Adds method to read the pooling types from model's files (#9506)

Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Flávia Béo
2024-11-07 05:42:40 -03:00
committed by GitHub
parent e036e527a0
commit aa9078fa03
10 changed files with 342 additions and 25 deletions

View File

@@ -30,6 +30,13 @@ def test_limit_mm_per_prompt_parser(arg, expected):
assert args.limit_mm_per_prompt == expected
def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--pooling-type=MEAN"])
engine_args = EngineArgs.from_cli_args(args=args)
assert engine_args.pooling_type == 'MEAN'
@pytest.mark.parametrize(
("arg"),
[