Fix PyTorch version to 2.0.1 in workflow (#1377)

This commit is contained in:
Woosuk Kwon
2023-10-16 11:27:17 -07:00
committed by GitHub
parent 9d9072a069
commit 348897af31
2 changed files with 6 additions and 4 deletions

View File

@@ -1,11 +1,12 @@
#!/bin/bash
python_executable=python$1
cuda_version=$2
pytorch_version=$2
cuda_version=$3
# Install torch
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
$python_executable -m pip install torch -f https://download.pytorch.org/whl/cu${cuda_version//./}/torch_stable.html
$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --index-url https://download.pytorch.org/whl/cu${cuda_version//./}
# Print version information
$python_executable --version