Improve literal dataclass field conversion to argparse argument (#17391)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -116,6 +116,18 @@ def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
|
||||
return next((th for th in type_hints if is_type(th, type)), None)
|
||||
|
||||
|
||||
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
|
||||
"""Convert Literal type hints to argparse kwargs."""
|
||||
type_hint = get_type(type_hints, Literal)
|
||||
choices = get_args(type_hint)
|
||||
choice_type = type(choices[0])
|
||||
if not all(isinstance(choice, choice_type) for choice in choices):
|
||||
raise ValueError(
|
||||
"All choices must be of the same type. "
|
||||
f"Got {choices} with types {[type(c) for c in choices]}")
|
||||
return {"type": choice_type, "choices": sorted(choices)}
|
||||
|
||||
|
||||
def is_not_builtin(type_hint: TypeHint) -> bool:
|
||||
"""Check if the class is not a built-in type."""
|
||||
return type_hint.__module__ != "builtins"
|
||||
@@ -151,15 +163,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
# Creates --no-<name> and --<name> flags
|
||||
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
||||
elif contains_type(type_hints, Literal):
|
||||
# Creates choices from Literal arguments
|
||||
type_hint = get_type(type_hints, Literal)
|
||||
choices = sorted(get_args(type_hint))
|
||||
kwargs[name]["choices"] = choices
|
||||
choice_type = type(choices[0])
|
||||
assert all(type(c) is choice_type for c in choices), (
|
||||
"All choices must be of the same type. "
|
||||
f"Got {choices} with types {[type(c) for c in choices]}")
|
||||
kwargs[name]["type"] = choice_type
|
||||
kwargs[name].update(literal_to_kwargs(type_hints))
|
||||
elif contains_type(type_hints, tuple):
|
||||
type_hint = get_type(type_hints, tuple)
|
||||
types = get_args(type_hint)
|
||||
@@ -191,6 +195,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
raise ValueError(
|
||||
f"Unsupported type {type_hints} for argument {name}.")
|
||||
|
||||
# If the type hint was a sequence of literals, use the helper function
|
||||
# to update the type and choices
|
||||
if get_origin(kwargs[name].get("type")) is Literal:
|
||||
kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))
|
||||
|
||||
# If None is in type_hints, make the argument optional.
|
||||
# But not if it's a bool, argparse will handle this better.
|
||||
if type(None) in type_hints and not contains_type(type_hints, bool):
|
||||
|
||||
Reference in New Issue
Block a user