Fix apply_top_k_top_p_triton called by non-cuda logits Tensor (#35030)

Signed-off-by: Xiao Li <ilx@meta.com>
This commit is contained in:
Xiao Li
2026-02-21 21:11:54 -08:00
committed by GitHub
parent cbd95a2dd1
commit 30132cd144

View File

@@ -248,7 +248,7 @@ def apply_top_k_top_p(
if p is None and k is None:
return logits
if HAS_TRITON and logits.shape[0] >= 8:
if HAS_TRITON and logits.shape[0] >= 8 and logits.is_cuda:
return apply_top_k_top_p_triton(logits, k, p)
# Use pytorch sort implementation for small batch sizes.