Adds method to read the pooling types from model's files (#9506)
Signed-off-by: Flavia Beo <flavia.beo@ibm.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
@@ -230,7 +230,7 @@ def quantize_model(model, quant_cfg, calib_dataloader=None):
|
||||
|
||||
def main(args):
|
||||
if not torch.cuda.is_available():
|
||||
raise EnvironmentError("GPU is required for inference.")
|
||||
raise OSError("GPU is required for inference.")
|
||||
|
||||
random.seed(RAND_SEED)
|
||||
np.random.seed(RAND_SEED)
|
||||
@@ -314,7 +314,7 @@ def main(args):
|
||||
|
||||
# Workaround for wo quantization
|
||||
if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
|
||||
with open(f"{export_path}/config.json", 'r') as f:
|
||||
with open(f"{export_path}/config.json") as f:
|
||||
tensorrt_llm_config = json.load(f)
|
||||
if args.qformat == "int8_wo":
|
||||
tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
|
||||
|
||||
Reference in New Issue
Block a user