Added dtype arg to benchmarks (#1228)

This commit is contained in:
kg6-sleipnir
2023-10-01 00:04:03 -04:00
committed by GitHub
parent 0967102c6d
commit b5a10eb0ef
2 changed files with 22 additions and 1 deletion

View File

@@ -23,6 +23,7 @@ def main(args: argparse.Namespace):
max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
)
sampling_params = SamplingParams(
@@ -87,5 +88,14 @@ if __name__ == '__main__':
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
args = parser.parse_args()
main(args)