[Bugfix] Fix handling of Tensorizer arguments for LoadConfig (#20643)
Signed-off-by: Sanger Steel <sangersteel@gmail.com>
This commit is contained in:
@@ -1003,41 +1003,27 @@ class EngineArgs:
|
||||
override_attention_dtype=self.override_attention_dtype,
|
||||
)
|
||||
|
||||
def valid_tensorizer_config_provided(self) -> bool:
    """Report whether no Tensorizer location was given in the extra config.

    Despite its name, this returns ``True`` when
    ``self.model_loader_extra_config`` does *not* name a Tensorizer
    location: it is falsy (``None`` or empty), or it contains neither the
    ``"tensorizer_uri"`` nor the ``"tensorizer_dir"`` key. It returns
    ``False`` as soon as one of those keys is present.

    Side effect: if the config object exposes a ``to_serializable()``
    method, ``self.model_loader_extra_config`` is replaced by the result of
    that call before the keys are inspected.

    Returns:
        ``False`` if the extra config contains ``"tensorizer_uri"`` or
        ``"tensorizer_dir"``; ``True`` otherwise.
    """
    extra_config = self.model_loader_extra_config
    if not extra_config:
        # Nothing was supplied, so there is no location key to find.
        return True

    if hasattr(extra_config, "to_serializable"):
        # Normalise a config object to its serializable form and keep that
        # form on the instance, so later readers see the same value.
        extra_config = extra_config.to_serializable()
        self.model_loader_extra_config = extra_config

    # Explicit membership tests replace the former "subscript and swallow
    # KeyError" probe; for dict configs the outcome is identical.
    location_keys = ("tensorizer_uri", "tensorizer_dir")
    return not any(key in extra_config for key in location_keys)
|
||||
def validate_tensorizer_args(self):
    """Mirror top-level Tensorizer options into the nested config.

    Every key of ``self.model_loader_extra_config`` that also appears in
    ``TensorizerConfig._fields`` has its value copied to
    ``self.model_loader_extra_config["tensorizer_config"]`` under the same
    key. The top-level entries are left in place.

    Raises:
        KeyError: if at least one key matches and the extra config has no
            ``"tensorizer_config"`` entry to copy into.
    """
    # Imported inside the method, as the original code does.
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

    extra_config = self.model_loader_extra_config
    matching = [
        name for name in extra_config
        if name in TensorizerConfig._fields
    ]
    for name in matching:
        # "tensorizer_config" is looked up only when there is something to
        # copy, so a config with no matching keys never touches it.
        extra_config["tensorizer_config"][name] = extra_config[name]
|
||||
|
||||
def create_load_config(self) -> LoadConfig:
|
||||
|
||||
if self.quantization == "bitsandbytes":
|
||||
self.load_format = "bitsandbytes"
|
||||
|
||||
if (self.load_format == "tensorizer"
|
||||
and self.valid_tensorizer_config_provided()):
|
||||
logger.info("Inferring Tensorizer args from %s", self.model)
|
||||
self.model_loader_extra_config = {"tensorizer_dir": self.model}
|
||||
else:
|
||||
logger.info(
|
||||
"Using Tensorizer args from --model-loader-extra-config. "
|
||||
"Note that you can now simply pass the S3 directory in the "
|
||||
"model tag instead of providing the JSON string.")
|
||||
if self.load_format == "tensorizer":
|
||||
if hasattr(self.model_loader_extra_config, "to_serializable"):
|
||||
self.model_loader_extra_config = (
|
||||
self.model_loader_extra_config.to_serializable())
|
||||
self.model_loader_extra_config["tensorizer_config"] = {}
|
||||
self.model_loader_extra_config["tensorizer_config"][
|
||||
"tensorizer_dir"] = self.model
|
||||
self.validate_tensorizer_args()
|
||||
|
||||
return LoadConfig(
|
||||
load_format=self.load_format,
|
||||
|
||||
Reference in New Issue
Block a user