Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -21,7 +21,6 @@ class MockModelConfig:
class MockTokenizerResult:
def __init__(self, input_ids):
self.input_ids = input_ids
@@ -45,9 +44,11 @@ def mock_async_tokenizer():
@pytest.fixture
def renderer(mock_model_config, mock_tokenizer):
return CompletionRenderer(model_config=mock_model_config,
tokenizer=mock_tokenizer,
async_tokenizer_pool={})
return CompletionRenderer(
model_config=mock_model_config,
tokenizer=mock_tokenizer,
async_tokenizer_pool={},
)
class TestRenderPrompt:
@@ -57,7 +58,8 @@ class TestRenderPrompt:
async def test_token_input(self, renderer):
tokens = [101, 7592, 2088]
results = await renderer.render_prompt(
prompt_or_prompts=tokens, config=RenderConfig(max_length=100))
prompt_or_prompts=tokens, config=RenderConfig(max_length=100)
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
@@ -66,7 +68,8 @@ class TestRenderPrompt:
async def test_token_list_input(self, renderer):
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
results = await renderer.render_prompt(
prompt_or_prompts=token_lists, config=RenderConfig(max_length=100))
prompt_or_prompts=token_lists, config=RenderConfig(max_length=100)
)
assert len(results) == 3
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@@ -75,14 +78,12 @@ class TestRenderPrompt:
@pytest.mark.asyncio
async def test_text_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088])
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100))
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@@ -90,15 +91,13 @@ class TestRenderPrompt:
@pytest.mark.asyncio
async def test_text_list_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088])
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
text_list_input = ["Hello world", "How are you?", "Good morning"]
results = await renderer.render_prompt(
prompt_or_prompts=text_list_input,
config=RenderConfig(max_length=100))
prompt_or_prompts=text_list_input, config=RenderConfig(max_length=100)
)
assert len(results) == 3
for result in results:
@@ -107,31 +106,31 @@ class TestRenderPrompt:
@pytest.mark.asyncio
async def test_no_truncation(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088])
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100))
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
assert "truncation" not in call_args.kwargs or call_args.kwargs[
"truncation"] is False
assert (
"truncation" not in call_args.kwargs
or call_args.kwargs["truncation"] is False
)
@pytest.mark.asyncio
async def test_truncation_positive(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088]) # Truncated
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
[101, 7592, 2088]
) # Truncated
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
config=RenderConfig(
max_length=100,
truncate_prompt_tokens=50))
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100, truncate_prompt_tokens=50),
)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
@@ -142,14 +141,14 @@ class TestRenderPrompt:
async def test_truncation_negative(self, renderer, mock_async_tokenizer):
# Test that negative truncation uses model's max_model_len
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088]) # Truncated to max_model_len
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
[101, 7592, 2088]
) # Truncated to max_model_len
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
config=RenderConfig(
max_length=200,
truncate_prompt_tokens=-1))
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=200, truncate_prompt_tokens=-1),
)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
@@ -159,12 +158,11 @@ class TestRenderPrompt:
@pytest.mark.asyncio
async def test_token_truncation_last_elements(self, renderer):
# Test that token truncation keeps the last N elements
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108,
109] # 10 tokens
results = await renderer.render_prompt(prompt_or_prompts=long_tokens,
config=RenderConfig(
max_length=100,
truncate_prompt_tokens=5))
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
results = await renderer.render_prompt(
prompt_or_prompts=long_tokens,
config=RenderConfig(max_length=100, truncate_prompt_tokens=5),
)
assert len(results) == 1
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
@@ -175,30 +173,30 @@ class TestRenderPrompt:
long_tokens = list(range(150)) # Exceeds max_model_len=100
with pytest.raises(ValueError, match="maximum context length"):
await renderer.render_prompt(prompt_or_prompts=long_tokens,
config=RenderConfig(max_length=100))
await renderer.render_prompt(
prompt_or_prompts=long_tokens, config=RenderConfig(max_length=100)
)
@pytest.mark.asyncio
async def test_no_tokenizer_for_text(self, mock_model_config):
renderer_no_tokenizer = CompletionRenderer(
model_config=mock_model_config,
tokenizer=None,
async_tokenizer_pool={})
model_config=mock_model_config, tokenizer=None, async_tokenizer_pool={}
)
with pytest.raises(ValueError, match="No tokenizer available"):
await renderer_no_tokenizer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100))
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
)
@pytest.mark.asyncio
async def test_token_input_with_needs_detokenization(
self, renderer, mock_async_tokenizer):
self, renderer, mock_async_tokenizer
):
# When needs_detokenization=True for token inputs, renderer should
# use the async tokenizer to decode and include the original text
# in the returned prompt object.
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
tokens = [1, 2, 3, 4]
results = await renderer.render_prompt(
@@ -213,7 +211,6 @@ class TestRenderPrompt:
class TestRenderEmbedPrompt:
def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
"""Helper to create base64-encoded tensor bytes"""
buffer = io.BytesIO()
@@ -244,9 +241,7 @@ class TestRenderEmbedPrompt:
torch.randn(8, 512, dtype=torch.float32),
torch.randn(12, 512, dtype=torch.float32),
]
embed_bytes_list = [
self._create_test_embed_bytes(t) for t in test_tensors
]
embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes_list,
@@ -307,13 +302,10 @@ class TestRenderEmbedPrompt:
assert results[0]["prompt_embeds"].shape == (10, 768)
@pytest.mark.asyncio
async def test_both_prompts_and_embeds(self, renderer,
mock_async_tokenizer):
async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
# Set up text tokenization
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 102, 103])
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
mock_async_tokenizer.return_value = MockTokenizerResult([101, 102, 103])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
# Create embed
test_tensor = torch.randn(5, 256, dtype=torch.float32)