Add pt_load_map_location to allow loading to cuda (#16869)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang
2025-05-01 23:23:42 -07:00
committed by GitHub
parent f192ca90e6
commit 109e15a335
6 changed files with 74 additions and 3 deletions

View File

@@ -64,6 +64,13 @@ def optional_type(
return _optional_type
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
if not re.match("^{.*}$", val):
return str(val)
else:
return optional_type(json.loads)(val)
@deprecated(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
@@ -187,6 +194,10 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["type"] = human_readable_int
elif contains_type(type_hints, float):
kwargs[name]["type"] = float
elif contains_type(type_hints,
dict) and (contains_type(type_hints, str) or any(
is_not_builtin(th) for th in type_hints)):
kwargs[name]["type"] = union_dict_and_str
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads)
@@ -371,6 +382,7 @@ class EngineArgs:
reasoning_parser: str = DecodingConfig.reasoning_backend
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
pt_load_map_location: str = LoadConfig.pt_load_map_location
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@@ -491,6 +503,8 @@ class EngineArgs:
type=str,
default=None,
help='Name or path of the QLoRA adapter.')
load_group.add_argument('--pt-load-map-location',
**load_kwargs["pt_load_map_location"])
# Guided decoding arguments
guided_decoding_kwargs = get_kwargs(DecodingConfig)
@@ -883,12 +897,14 @@ class EngineArgs:
if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes"
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
use_tqdm_on_load=self.use_tqdm_on_load,
pt_load_map_location=self.pt_load_map_location,
)
def create_speculative_config(
@@ -1513,7 +1529,7 @@ def _warn_or_fallback(feature_name: str) -> bool:
def human_readable_int(value):
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024