[Bugfix] Fix ubatch wrapper num_tokens calculate (#33694)
Signed-off-by: jiangkuaixue123 <jiangxiaozhou111@163.com>
This commit is contained in:
@@ -412,9 +412,7 @@ class UBatchWrapper:
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
num_tokens = (
|
||||
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
|
||||
) * 2
|
||||
num_tokens = sum(ubatch_slice.num_tokens for ubatch_slice in ubatch_slices)
|
||||
input_ids = kwargs["input_ids"]
|
||||
positions = kwargs["positions"]
|
||||
intermediate_tensors = kwargs["intermediate_tensors"]
|
||||
|
||||
Reference in New Issue
Block a user