[Bugfix] Fix ubatch wrapper num_tokens calculate (#33694)
Signed-off-by: jiangkuaixue123 <jiangxiaozhou111@163.com>
This commit is contained in:
@@ -412,9 +412,7 @@ class UBatchWrapper:
|
|||||||
|
|
||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
slot_mapping = forward_context.slot_mapping
|
slot_mapping = forward_context.slot_mapping
|
||||||
num_tokens = (
|
num_tokens = sum(ubatch_slice.num_tokens for ubatch_slice in ubatch_slices)
|
||||||
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
|
|
||||||
) * 2
|
|
||||||
input_ids = kwargs["input_ids"]
|
input_ids = kwargs["input_ids"]
|
||||||
positions = kwargs["positions"]
|
positions = kwargs["positions"]
|
||||||
intermediate_tensors = kwargs["intermediate_tensors"]
|
intermediate_tensors = kwargs["intermediate_tensors"]
|
||||||
|
|||||||
Reference in New Issue
Block a user