[Bugfix] Fix load config when using bools (#9533)
This commit is contained in:
@@ -1155,6 +1155,18 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
|
||||
return wrapper
|
||||
|
||||
|
||||
class StoreBoolean(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
if values.lower() == "true":
|
||||
setattr(namespace, self.dest, True)
|
||||
elif values.lower() == "false":
|
||||
setattr(namespace, self.dest, False)
|
||||
else:
|
||||
raise ValueError(f"Invalid boolean value: {values}. "
|
||||
"Expected 'true' or 'false'.")
|
||||
|
||||
|
||||
class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
"""ArgumentParser that allows both underscore and dash in names."""
|
||||
|
||||
@@ -1163,7 +1175,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
args = sys.argv[1:]
|
||||
|
||||
if '--config' in args:
|
||||
args = FlexibleArgumentParser._pull_args_from_config(args)
|
||||
args = self._pull_args_from_config(args)
|
||||
|
||||
# Convert underscores to dashes and vice versa in argument names
|
||||
processed_args = []
|
||||
@@ -1181,8 +1193,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
|
||||
return super().parse_args(processed_args, namespace)
|
||||
|
||||
@staticmethod
|
||||
def _pull_args_from_config(args: List[str]) -> List[str]:
|
||||
def _pull_args_from_config(self, args: List[str]) -> List[str]:
|
||||
"""Method to pull arguments specified in the config file
|
||||
into the command-line args variable.
|
||||
|
||||
@@ -1226,7 +1237,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
|
||||
file_path = args[index + 1]
|
||||
|
||||
config_args = FlexibleArgumentParser._load_config_file(file_path)
|
||||
config_args = self._load_config_file(file_path)
|
||||
|
||||
# 0th index is for {serve,chat,complete}
|
||||
# followed by model_tag (only for serve)
|
||||
@@ -1247,8 +1258,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
|
||||
return args
|
||||
|
||||
@staticmethod
|
||||
def _load_config_file(file_path: str) -> List[str]:
|
||||
def _load_config_file(self, file_path: str) -> List[str]:
|
||||
"""Loads a yaml file and returns the key value pairs as a
|
||||
flattened list with argparse like pattern
|
||||
```yaml
|
||||
@@ -1282,9 +1292,18 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
Make sure path is correct", file_path)
|
||||
raise ex
|
||||
|
||||
store_boolean_arguments = [
|
||||
action.dest for action in self._actions
|
||||
if isinstance(action, StoreBoolean)
|
||||
]
|
||||
|
||||
for key, value in config.items():
|
||||
processed_args.append('--' + key)
|
||||
processed_args.append(str(value))
|
||||
if isinstance(value, bool) and key not in store_boolean_arguments:
|
||||
if value:
|
||||
processed_args.append('--' + key)
|
||||
else:
|
||||
processed_args.append('--' + key)
|
||||
processed_args.append(str(value))
|
||||
|
||||
return processed_args
|
||||
|
||||
|
||||
Reference in New Issue
Block a user