diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py index 91f5e0290..a2eb5ff3a 100644 --- a/vllm/_xpu_ops.py +++ b/vllm/_xpu_ops.py @@ -426,7 +426,8 @@ class xpu_ops: mask = positions <= index_end_pos # mask: [B * N, L] logits = logits.masked_fill(~mask, float("-inf")) - topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K] + real_topk = min(topk_tokens, logits.shape[-1]) + topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32) # [B * N, K] # ensure we don't set indices for the top k # that is out of range(masked already) # this will happen if context length is shorter than K