Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -18,9 +18,9 @@ logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class PEFTHelper:
|
||||
"""
|
||||
"""
|
||||
A helper class for PEFT configurations, specifically designed for LoRA.
|
||||
This class handles configuration validation, compatibility checks for
|
||||
This class handles configuration validation, compatibility checks for
|
||||
various LoRA implementations.
|
||||
"""
|
||||
|
||||
@@ -71,37 +71,38 @@ class PEFTHelper:
|
||||
# Identify any missing required fields
|
||||
missing_fields = required_fields - set(config_dict.keys())
|
||||
if missing_fields:
|
||||
raise ValueError(
|
||||
f"Missing required configuration fields: {missing_fields}")
|
||||
raise ValueError(f"Missing required configuration fields: {missing_fields}")
|
||||
|
||||
# Filter out fields that aren't defined in the class
|
||||
filtered_dict = {
|
||||
k: v
|
||||
for k, v in config_dict.items() if k in class_fields
|
||||
}
|
||||
filtered_dict = {k: v for k, v in config_dict.items() if k in class_fields}
|
||||
return cls(**filtered_dict)
|
||||
|
||||
@classmethod
|
||||
def from_local_dir(
|
||||
cls,
|
||||
lora_path: str,
|
||||
max_position_embeddings: Optional[int],
|
||||
tensorizer_config_dict: Optional[dict] = None) -> "PEFTHelper":
|
||||
cls,
|
||||
lora_path: str,
|
||||
max_position_embeddings: Optional[int],
|
||||
tensorizer_config_dict: Optional[dict] = None,
|
||||
) -> "PEFTHelper":
|
||||
lora_config_path = os.path.join(lora_path, "adapter_config.json")
|
||||
|
||||
if tensorizer_config_dict:
|
||||
tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
from tensorizer.stream_io import open_stream
|
||||
lora_config_path = os.path.join(tensorizer_config.tensorizer_dir,
|
||||
"adapter_config.json")
|
||||
with open_stream(lora_config_path,
|
||||
mode="rb",
|
||||
**tensorizer_args.stream_kwargs) as f:
|
||||
|
||||
lora_config_path = os.path.join(
|
||||
tensorizer_config.tensorizer_dir, "adapter_config.json"
|
||||
)
|
||||
with open_stream(
|
||||
lora_config_path, mode="rb", **tensorizer_args.stream_kwargs
|
||||
) as f:
|
||||
config = json.load(f)
|
||||
|
||||
logger.info("Successfully deserialized LoRA config from %s",
|
||||
tensorizer_config.tensorizer_dir)
|
||||
logger.info(
|
||||
"Successfully deserialized LoRA config from %s",
|
||||
tensorizer_config.tensorizer_dir,
|
||||
)
|
||||
|
||||
else:
|
||||
with open(lora_config_path) as f:
|
||||
@@ -112,16 +113,16 @@ class PEFTHelper:
|
||||
|
||||
def validate_legal(self, lora_config: LoRAConfig) -> None:
|
||||
"""
|
||||
Validates the LoRA configuration settings against application
|
||||
Validates the LoRA configuration settings against application
|
||||
constraints and requirements.
|
||||
"""
|
||||
error_msg = self._validate_features()
|
||||
if self.r > lora_config.max_lora_rank:
|
||||
error_msg.append(
|
||||
f"LoRA rank {self.r} is greater than max_lora_rank"
|
||||
f" {lora_config.max_lora_rank}.")
|
||||
f" {lora_config.max_lora_rank}."
|
||||
)
|
||||
if self.bias != "none" and not lora_config.bias_enabled:
|
||||
error_msg.append(
|
||||
"Adapter bias cannot be used without bias_enabled.")
|
||||
error_msg.append("Adapter bias cannot be used without bias_enabled.")
|
||||
if error_msg:
|
||||
raise ValueError(f"{' '.join(error_msg)}")
|
||||
|
||||
Reference in New Issue
Block a user