[ModelRunner V2] Revert token rank comparison difference for now (#34017)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -68,7 +68,7 @@ def _ranks_kernel(
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||
- n += tl.sum((logits > x).to(tl.int32))
|
||||
+ n += tl.sum((logits >= x).to(tl.int32))
|
||||
tl.store(output_ptr + req_idx, n)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user