[Doc] Add troubleshooting for Triton PTX error about undefined gpu-name (#31338)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Isotr0py
2025-12-25 18:26:34 +08:00
committed by GitHub
parent f15185fbdb
commit 2532f437ee

View File

@@ -320,6 +320,32 @@ This indicates vLLM failed to initialize the NCCL communicator, possibly due to
If you see an error like `RuntimeError: CUDA error: the provided PTX was compiled with an unsupported toolchain.`, it means that the CUDA PTX in vLLM's wheels was compiled with a toolchain unsupported by your system. The released vLLM wheels have to be compiled with a specific version of CUDA toolkit, and the compiled code might fail to run on lower versions of CUDA drivers. Read [cuda compatibility](https://docs.nvidia.com/deploy/cuda-compatibility/) for more details. The solution is to install `cuda-compat` package from your package manager. For example, on Ubuntu, you can run `sudo apt-get install cuda-compat-12-9`, and then add `export LD_LIBRARY_PATH=/usr/local/cuda-12.9/compat:$LD_LIBRARY_PATH` to your `.bashrc` file. When successfully installed, you should see that the output of `nvidia-smi` will show `CUDA Version: 12.9`. Note that we use CUDA 12.9 as an example here, you may want to install a higher version of cuda-compat package in case vLLM's default CUDA version goes higher.
## ptxas fatal: Value 'sm_110a' is not defined for option 'gpu-name'
If you use triton kernels with cuda 13, you might see an error like `ptxas fatal: Value 'sm_110a' is not defined for option 'gpu-name'`:
```text
(EngineCore_0 pid=9492) triton.runtime.errors.PTXASError: PTXAS error: Internal Triton PTX codegen error
(EngineCore_0 pid=9492) `ptxas` stderr:
(EngineCore_0 pid=9492) ptxas fatal : Value 'sm_110a' is not defined for option 'gpu-name'
(EngineCore_0 pid=9492)
(EngineCore_0 pid=9492) Repro command: /home/jetson/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas -lineinfo -v --gpu-name=sm_110a /tmp/tmp95oy_b9d.ptx -o /tmp/tmp95oy_b9d.ptx.o
(EngineCore_0 pid=9492)
outputs = self.engine_core.get_output()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/.venv/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 668, in get_output
raise self._format_exception(outputs) from None
vllm.v1.engine.exceptions.EngineDeadError: EngineCore encountered an issue. See stack trace (above) for the root cause.
```
It means that the ptxas in triton bundle not compatible with your device. You need to set `TRITON_PTXAS_PATH` environment variable to use cuda toolkit's ptxas manually instead:
```shell
export CUDA_HOME=/usr/local/cuda
export TRITON_PTXAS_PATH="${CUDA_HOME}/bin/ptxas"
export PATH="${CUDA_HOME}/bin:$PATH"
```
## Known Issues
- In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](https://github.com/vllm-project/vllm/pull/6759).