[Bugfix](xpu): prevent “selected index k out of range” in TP decode path (#37259)
Signed-off-by: zhenzhao <zhenzhao@habana.ai>
This commit is contained in:
@@ -426,7 +426,8 @@ class xpu_ops:
|
||||
mask = positions <= index_end_pos
|
||||
# mask: [B * N, L]
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K]
|
||||
real_topk = min(topk_tokens, logits.shape[-1])
|
||||
topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32) # [B * N, min(K, L)]
|
||||
# ensure we don't set indices for the top k
|
||||
# that are out of range (already masked)
|
||||
# this will happen if context length is shorter than K
|
||||
|
||||
Reference in New Issue
Block a user