[TPU] Add Load-time W8A16 quantization for TPU Backend (#7005)
@@ -244,6 +244,7 @@ class ModelConfig:
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
             "fbgemm_fp8", "compressed_tensors", "compressed-tensors"
         ]
+        tpu_supported_quantization = ["tpu_int8"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 
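The new tpu_supported_quantization list acts as a whitelist that the TPU check in the next hunk validates against. Below is a minimal standalone sketch of the same pattern; the helper name check_tpu_quantization is illustrative and not part of this commit.

# Minimal sketch of the whitelist check added in this commit; names here
# are illustrative, not vLLM's own.
tpu_supported_quantization = ["tpu_int8"]

def check_tpu_quantization(quantization: str) -> None:
    # Mirror the diff: reject any method not on the TPU whitelist.
    if quantization.lower() not in tpu_supported_quantization:
        raise ValueError(
            f"{quantization} quantization is currently not "
            f"supported in TPU Backend.")

check_tpu_quantization("tpu_int8")  # passes silently
# check_tpu_quantization("awq")     # would raise ValueError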
@@ -282,6 +283,11 @@ class ModelConfig:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")
+            if is_tpu(
+            ) and self.quantization not in tpu_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not "
+                    f"supported in TPU Backend.")
             if self.quantization not in optimized_quantization_methods:
                 logger.warning(
                     "%s quantization is not fully "
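With this check in place, requesting an unsupported method on the TPU backend fails fast at config time, while tpu_int8 is allowed through. A hedged usage sketch, assuming a TPU host; the model name is a placeholder and only the quantization value comes from this commit.

# Assumed usage on a TPU host; model name is a placeholder.
from vllm import LLM

llm = LLM(model="facebook/opt-125m", quantization="tpu_int8")
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)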