Added dtype arg to benchmarks (#1228)
This commit is contained in:
@@ -23,6 +23,7 @@ def main(args: argparse.Namespace):
|
||||
max_num_seqs=args.batch_size,
|
||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
dtype=args.dtype,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
@@ -87,5 +88,14 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default='auto',
|
||||
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
||||
help='data type for model weights and activations. '
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user