Adds a method to read the pooling types from the model's files (#9506)

Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Flávia Béo
2024-11-07 05:42:40 -03:00
committed by GitHub
parent e036e527a0
commit aa9078fa03
10 changed files with 342 additions and 25 deletions

View File

@@ -230,7 +230,7 @@ def quantize_model(model, quant_cfg, calib_dataloader=None):
def main(args):
if not torch.cuda.is_available():
raise EnvironmentError("GPU is required for inference.")
raise OSError("GPU is required for inference.")
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
@@ -314,7 +314,7 @@ def main(args):
# Workaround for wo quantization
if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
with open(f"{export_path}/config.json", 'r') as f:
with open(f"{export_path}/config.json") as f:
tensorrt_llm_config = json.load(f)
if args.qformat == "int8_wo":
tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'