Add pt_load_map_location to allow loading to cuda (#16869)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
@@ -64,6 +64,13 @@ def optional_type(
|
||||
return _optional_type
|
||||
|
||||
|
||||
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
|
||||
if not re.match("^{.*}$", val):
|
||||
return str(val)
|
||||
else:
|
||||
return optional_type(json.loads)(val)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"Passing a JSON argument as a string containing comma separated key=value "
|
||||
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
|
||||
@@ -187,6 +194,10 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
kwargs[name]["type"] = human_readable_int
|
||||
elif contains_type(type_hints, float):
|
||||
kwargs[name]["type"] = float
|
||||
elif contains_type(type_hints,
|
||||
dict) and (contains_type(type_hints, str) or any(
|
||||
is_not_builtin(th) for th in type_hints)):
|
||||
kwargs[name]["type"] = union_dict_and_str
|
||||
elif contains_type(type_hints, dict):
|
||||
# Dict arguments will always be optional
|
||||
kwargs[name]["type"] = optional_type(json.loads)
|
||||
@@ -371,6 +382,7 @@ class EngineArgs:
|
||||
reasoning_parser: str = DecodingConfig.reasoning_backend
|
||||
|
||||
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
|
||||
pt_load_map_location: str = LoadConfig.pt_load_map_location
|
||||
|
||||
def __post_init__(self):
|
||||
# support `EngineArgs(compilation_config={...})`
|
||||
@@ -491,6 +503,8 @@ class EngineArgs:
|
||||
type=str,
|
||||
default=None,
|
||||
help='Name or path of the QLoRA adapter.')
|
||||
load_group.add_argument('--pt-load-map-location',
|
||||
**load_kwargs["pt_load_map_location"])
|
||||
|
||||
# Guided decoding arguments
|
||||
guided_decoding_kwargs = get_kwargs(DecodingConfig)
|
||||
@@ -883,12 +897,14 @@ class EngineArgs:
|
||||
|
||||
if self.quantization == "bitsandbytes":
|
||||
self.load_format = "bitsandbytes"
|
||||
|
||||
return LoadConfig(
|
||||
load_format=self.load_format,
|
||||
download_dir=self.download_dir,
|
||||
model_loader_extra_config=self.model_loader_extra_config,
|
||||
ignore_patterns=self.ignore_patterns,
|
||||
use_tqdm_on_load=self.use_tqdm_on_load,
|
||||
pt_load_map_location=self.pt_load_map_location,
|
||||
)
|
||||
|
||||
def create_speculative_config(
|
||||
@@ -1513,7 +1529,7 @@ def _warn_or_fallback(feature_name: str) -> bool:
|
||||
def human_readable_int(value):
|
||||
"""Parse human-readable integers like '1k', '2M', etc.
|
||||
Including decimal values with decimal multipliers.
|
||||
|
||||
|
||||
Examples:
|
||||
- '1k' -> 1,000
|
||||
- '1K' -> 1,024
|
||||
|
||||
Reference in New Issue
Block a user