Fix apply_top_k_top_p_triton called by non-cuda logits Tensor (#35030)
Signed-off-by: Xiao Li <ilx@meta.com>
This commit is contained in:
@@ -248,7 +248,7 @@ def apply_top_k_top_p(
|
||||
if p is None and k is None:
|
||||
return logits
|
||||
|
||||
if HAS_TRITON and logits.shape[0] >= 8:
|
||||
if HAS_TRITON and logits.shape[0] >= 8 and logits.is_cuda:
|
||||
return apply_top_k_top_p_triton(logits, k, p)
|
||||
|
||||
# Use pytorch sort implementation for small batch sizes.
|
||||
|
||||
Reference in New Issue
Block a user