[Refactor] Pass tokenizer explicitly instead of binding to prompt update (#23542)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-08-25 21:31:57 +08:00
committed by GitHub
parent e269be2ba2
commit 6879cd80ae
4 changed files with 95 additions and 144 deletions

View File

@@ -243,7 +243,7 @@ def test_find_token_matches(
mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = {
-key: update_type(key, target, []).resolve(mock_tokenizer, 0)
+key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items()
}
result = {
@@ -392,7 +392,7 @@ def test_find_text_matches(
mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = {
-key: update_type(key, target, []).resolve(mock_tokenizer, 0)
+key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items()
}
result = {
@@ -559,10 +559,8 @@ def test_find_update_text(
) in expected_by_update_type_mm_count.items():
for mm_count, expected in expected_by_mm_count.items():
mm_prompt_updates = {
-key: [[
-update_type(key, target,
-repl_by_key[key]).resolve(mock_tokenizer, i)
-] for i in range(mm_count)]
+key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
+for i in range(mm_count)]
for key, target in target_by_key.items()
}
@@ -731,10 +729,8 @@ def test_find_update_tokens(
) in expected_by_update_type_mm_count.items():
for mm_count, expected in expected_by_mm_count.items():
mm_prompt_updates = {
-key: [[
-update_type(key, target,
-repl_by_key[key]).resolve(mock_tokenizer, i)
-] for i in range(mm_count)]
+key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
+for i in range(mm_count)]
for key, target in target_by_key.items()
}
@@ -879,12 +875,11 @@ def test_find_mm_placeholders(
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_updates = {
-key: [[update_type(key, [], repl).resolve(mock_tokenizer, i)]
-for i in range(3)]
+key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
for key, repl in repl_by_key.items()
}
-result = find_mm_placeholders(prompt, mm_prompt_updates)
+result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)
# Only displayed on error
print("result:", result)