[Bugfix](xpu): prevent “selected index k out of range” in TP decode path (#37259)

Signed-off-by: zhenzhao <zhenzhao@habana.ai>
This commit is contained in:
zhao, zhenhui
2026-03-17 19:14:07 +08:00
committed by GitHub
parent 9c7cab5ebb
commit 4af9ed21cb

View File

@@ -426,7 +426,8 @@ class xpu_ops:
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float("-inf"))
topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K]
real_topk = min(topk_tokens, logits.shape[-1])
topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32) # [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K