[Bugfix] Fix ubatch wrapper num_tokens calculate (#33694)

Signed-off-by: jiangkuaixue123 <jiangxiaozhou111@163.com>
This commit is contained in:
jiangkuaixue123
2026-02-05 00:41:45 +08:00
committed by GitHub
parent 80f921ba4b
commit 87d9a26166

View File

@@ -412,9 +412,7 @@ class UBatchWrapper:
attn_metadata = forward_context.attn_metadata
slot_mapping = forward_context.slot_mapping
num_tokens = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
) * 2
num_tokens = sum(ubatch_slice.num_tokens for ubatch_slice in ubatch_slices)
input_ids = kwargs["input_ids"]
positions = kwargs["positions"]
intermediate_tensors = kwargs["intermediate_tensors"]