Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author:    Harry Mellor
Date:      2025-10-05 15:06:22 +01:00
Committer: GitHub
Parent:    17edd8a807
Commit:    d6953beb91

1508 changed files with 115244 additions and 94146 deletions
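Most of those changes follow a handful of mechanical patterns that repeat throughout the excerpt below: yapf aligns continuation lines under the opening delimiter, while ruff format (black-compatible) uses a four-space hanging indent, a dedented closing bracket, and a magic trailing comma that keeps one argument per line. An illustrative before/after (the function and argument names are placeholders, not taken from the diff):

def some_function(first: int, second: int, third: int) -> int:
    return first + second + third

# Old style (yapf): continuation lines aligned under the opening delimiter.
result = some_function(1, 2,
                       3)

# New style (ruff format): hanging indent, dedented closing bracket, and a
# magic trailing comma that forces one argument per line.
result = some_function(
    1,
    2,
    3,
)

The same transformation explains nearly every hunk in the file below.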


@@ -42,8 +42,7 @@ class RenderConfig:
     needs_detokenization: Optional[bool] = False
     """If True, detokenize IDs back to text for inclusion in outputs."""

-    def verify_truncate_prompt_tokens(
-            self, model_config: ModelConfig) -> Optional[int]:
+    def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> Optional[int]:
         """Validate and normalize `truncate_prompt_tokens` parameter."""
         truncate_prompt_tokens = self.truncate_prompt_tokens
         if truncate_prompt_tokens is None:
@@ -59,7 +58,8 @@ class RenderConfig:
         if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
             raise ValueError(
                 f"{truncate_prompt_tokens=} cannot be greater than "
-                f"{max_length=}. Please select a smaller truncation size.")
+                f"{max_length=}. Please select a smaller truncation size."
+            )

         return truncate_prompt_tokens
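For orientation, the validation exercised in the two hunks above reduces to roughly the following standalone sketch. The signature is simplified, and the negative-value branch ("-1 means truncate to the model maximum") is an assumption about surrounding vLLM behavior, not something these hunks show:

from typing import Optional

def verify_truncate_prompt_tokens(
    truncate_prompt_tokens: Optional[int], max_length: Optional[int]
) -> Optional[int]:
    """Standalone sketch of the check shown above (simplified signature)."""
    if truncate_prompt_tokens is None:
        return None
    if truncate_prompt_tokens < 0:
        # Assumed convention: -1 means "truncate to the model's maximum".
        return max_length
    if max_length is not None and truncate_prompt_tokens > max_length:
        raise ValueError(
            f"{truncate_prompt_tokens=} cannot be greater than "
            f"{max_length=}. Please select a smaller truncation size."
        )
    return truncate_prompt_tokens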
@@ -67,13 +67,13 @@ class RenderConfig:
 class BaseRenderer(ABC):
     """
     Base class for unified input processing and rendering.

     The Renderer serves as a unified input processor that consolidates
     tokenization, chat template formatting, and multimodal input handling
     into a single component.
     It converts high-level API requests (OpenAI-style JSON) into token IDs and
     multimodal features ready for engine consumption.

     Key responsibilities:
     - Convert text prompts to token sequences with proper special tokens
     - Apply chat templates and format conversations
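Concrete renderers subclass this base; a minimal skeleton of the relationship might look as follows. This is an illustrative shape only: `render_prompt` is inferred from the completion-side override further down, and the parameter types are stubbed out with `Any`.

from abc import ABC, abstractmethod
from typing import Any, Optional

class RendererSketch(ABC):
    """Illustrative skeleton only; not the real BaseRenderer."""

    def __init__(self, model_config: Any, tokenizer: Optional[Any] = None):
        self.model_config = model_config
        self.tokenizer = tokenizer

    @abstractmethod
    async def render_prompt(self, *, prompt_or_prompts: Any, config: Any) -> list:
        """Convert text or token prompts into engine-ready token prompts."""

    @abstractmethod
    async def render_prompt_and_embeds(
        self, *, prompt_or_prompts: Any = None, prompt_embeds: Any = None, config: Any
    ) -> list:
        """Render prompts and/or precomputed embedding prompts."""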
@@ -112,7 +112,7 @@ class BaseRenderer(ABC):
                 - ``list[int]``: Single pre-tokenized sequence.
                 - ``list[list[int]]``: Batch of pre-tokenized sequences.
             config: Render configuration controlling how prompts are prepared
-              (e.g., tokenization and length handling).
+                (e.g., tokenization and length handling).

         Returns:
             list[EngineTokensPrompt]: Engine-ready token prompts.
@@ -126,8 +126,9 @@ class BaseRenderer(ABC):
     async def render_prompt_and_embeds(
         self,
         *,
-        prompt_or_prompts: Optional[Union[str, list[str], list[int],
-                                          list[list[int]]]] = None,
+        prompt_or_prompts: Optional[
+            Union[str, list[str], list[int], list[list[int]]]
+        ] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
         config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
@@ -144,7 +145,7 @@ class BaseRenderer(ABC):
             prompt_embeds: Base64-encoded bytes (or list thereof) containing a
                 torch-saved tensor to be used as prompt embeddings.
             config: Render configuration controlling how prompts are prepared
-              (e.g., tokenization and length handling).
+                (e.g., tokenization and length handling).

         Returns:
             list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
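Per the docstring, `prompt_embeds` carries a torch-saved tensor as base64 bytes. A hedged sketch of producing such a payload on the client side (the helper name is hypothetical; only the transport format comes from the docstring above):

import base64
import io

import torch

def encode_prompt_embeds(embeds: torch.Tensor) -> bytes:
    # Serialize exactly as the docstring describes: a torch-saved tensor,
    # base64-encoded so it can travel inside a JSON request body.
    buffer = io.BytesIO()
    torch.save(embeds, buffer)
    return base64.b64encode(buffer.getvalue())

payload = encode_prompt_embeds(torch.zeros(4, 8))  # 4 tokens, hidden size 8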
@@ -195,13 +196,13 @@ class BaseRenderer(ABC):
 class CompletionRenderer(BaseRenderer):
-
     def __init__(
         self,
         model_config: ModelConfig,
         tokenizer: Optional[AnyTokenizer] = None,
-        async_tokenizer_pool: Optional[dict[AnyTokenizer,
-                                            AsyncMicrobatchTokenizer]] = None,
+        async_tokenizer_pool: Optional[
+            dict[AnyTokenizer, AsyncMicrobatchTokenizer]
+        ] = None,
     ):
         super().__init__(model_config, tokenizer)
         self.async_tokenizer_pool = async_tokenizer_pool
@@ -214,28 +215,31 @@ class CompletionRenderer(BaseRenderer):
         config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.

         Uses async tokenizer pooling for improved performance. See base class
         for detailed parameter documentation.
         """
-        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
-            self.model_config)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config)
         if truncate_prompt_tokens == 0:
             return []

-        tasks = (self._create_prompt(
-            prompt_input,
-            config=config,
-            truncate_prompt_tokens=truncate_prompt_tokens,
-        ) for prompt_input in parse_raw_prompts(prompt_or_prompts))
+        tasks = (
+            self._create_prompt(
+                prompt_input,
+                config=config,
+                truncate_prompt_tokens=truncate_prompt_tokens,
+            )
+            for prompt_input in parse_raw_prompts(prompt_or_prompts)
+        )

         return await asyncio.gather(*tasks)

     async def render_prompt_and_embeds(
         self,
         *,
-        prompt_or_prompts: Optional[Union[str, list[str], list[int],
-                                          list[list[int]]]] = None,
+        prompt_or_prompts: Optional[
+            Union[str, list[str], list[int], list[list[int]]]
+        ] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
         config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
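The reformatted generator expression above feeds `asyncio.gather`, so each prompt in a batch is tokenized concurrently. The same pattern in isolation, with a toy coroutine standing in for the async tokenizer:

import asyncio

async def tokenize(text: str) -> list[int]:
    # Stand-in for the real async tokenizer call.
    return [len(word) for word in text.split()]

async def main() -> None:
    prompts = ["first prompt", "second prompt"]
    tasks = (tokenize(p) for p in prompts)
    results = await asyncio.gather(*tasks)
    print(results)  # [[5, 6], [6, 6]]

asyncio.run(main())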
@@ -243,8 +247,7 @@ class CompletionRenderer(BaseRenderer):
         Render text/token prompts and/or precomputed embedding prompts. At
         least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
         """
-        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
-            self.model_config)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config)
         if truncate_prompt_tokens == 0:
             return []
@@ -252,8 +255,10 @@ class CompletionRenderer(BaseRenderer):
         if prompt_embeds is not None:
             rendered.extend(
-                self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
-                                        config.cache_salt))
+                self.load_prompt_embeds(
+                    prompt_embeds, truncate_prompt_tokens, config.cache_salt
+                )
+            )

         if prompt_or_prompts is None or prompt_or_prompts == "":
             return rendered
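Loading is the inverse of the transport format described earlier: base64-decode, then `torch.load`. A minimal sketch of the core of `load_prompt_embeds` under that assumption (the real helper also applies truncation and attaches `cache_salt`, per the call above):

import base64
import io

import torch

def decode_prompt_embeds(payload: bytes) -> torch.Tensor:
    # Illustrative inverse of the base64 + torch.save transport format; the
    # real load_prompt_embeds does more (truncation, cache_salt) than this.
    return torch.load(io.BytesIO(base64.b64decode(payload)))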
@@ -266,8 +271,8 @@ class CompletionRenderer(BaseRenderer):
         return rendered

     def _maybe_apply_truncation(
-            self, token_ids: list[int],
-            truncate_prompt_tokens: Optional[int]) -> list[int]:
+        self, token_ids: list[int], truncate_prompt_tokens: Optional[int]
+    ) -> list[int]:
         """Apply truncation to token sequence."""
         if truncate_prompt_tokens is None:
             return token_ids
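The truncation applied here keeps a trailing window of the token list; retaining the last N tokens is an assumption consistent with left-truncating long prompts, not something this hunk states. A runnable sketch:

from typing import Optional

def maybe_apply_truncation(
    token_ids: list[int], truncate_prompt_tokens: Optional[int]
) -> list[int]:
    """Sketch: keep the trailing window of tokens when truncation is set."""
    if truncate_prompt_tokens is None:
        return token_ids
    # Assumed semantics: retain the last N tokens of the prompt.
    return token_ids[-truncate_prompt_tokens:]

assert maybe_apply_truncation([1, 2, 3, 4], 2) == [3, 4]
assert maybe_apply_truncation([1, 2, 3, 4], None) == [1, 2, 3, 4]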
@@ -319,24 +324,26 @@ class CompletionRenderer(BaseRenderer):
         async_tokenizer = self._get_async_tokenizer()

         # Handle encoder-specific preprocessing
-        if (self.model_config.encoder_config is not None
-                and self.model_config.encoder_config.get(
-                    "do_lower_case", False)):
+        if (
+            self.model_config.encoder_config is not None
+            and self.model_config.encoder_config.get("do_lower_case", False)
+        ):
             text = text.lower()

         # Tokenize texts
         if truncate_prompt_tokens is None:
-            encoded = await async_tokenizer(
-                text, add_special_tokens=add_special_tokens)
+            encoded = await async_tokenizer(text, add_special_tokens=add_special_tokens)
         else:
             encoded = await async_tokenizer(
                 text,
                 add_special_tokens=add_special_tokens,
                 truncation=True,
-                max_length=truncate_prompt_tokens)
+                max_length=truncate_prompt_tokens,
+            )

-        return self._create_tokens_prompt(encoded.input_ids, max_length,
-                                          cache_salt, text)
+        return self._create_tokens_prompt(
+            encoded.input_ids, max_length, cache_salt, text
+        )

     async def _create_prompt_from_token_ids(
         self,
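The two tokenizer calls in the preceding hunk differ only in whether truncation kwargs are passed through. With a plain Hugging Face tokenizer the same two call shapes look like this (synchronous here; the renderer routes them through the async microbatch wrapper):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# No truncation: tokenize the full text.
encoded = tokenizer("some long prompt", add_special_tokens=True)

# With truncation: cap the sequence at max_length tokens.
encoded_truncated = tokenizer(
    "some long prompt",
    add_special_tokens=True,
    truncation=True,
    max_length=2,
)
print(len(encoded_truncated.input_ids))  # 2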
@@ -347,18 +354,19 @@ class CompletionRenderer(BaseRenderer):
         needs_detokenization: Optional[bool] = False,
     ) -> EngineTokensPrompt:
         """Optionally detokenize token IDs and build a tokens prompt."""
-        token_ids = self._maybe_apply_truncation(token_ids,
-                                                 truncate_prompt_tokens)
+        token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens)

         prompt = None
         if needs_detokenization:
             async_tokenizer = self._get_async_tokenizer()
             prompt = await async_tokenizer.decode(token_ids)

-        return self._create_tokens_prompt(token_ids=token_ids,
-                                          max_length=max_length,
-                                          cache_salt=cache_salt,
-                                          prompt=prompt)
+        return self._create_tokens_prompt(
+            token_ids=token_ids,
+            max_length=max_length,
+            cache_salt=cache_salt,
+            prompt=prompt,
+        )

     def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
         """Get or create async tokenizer using shared pool."""
@@ -368,8 +376,7 @@ class CompletionRenderer(BaseRenderer):
         tokenizer = self.tokenizer
         if self.tokenizer is None:
-            raise ValueError(
-                "No tokenizer available for text input processing")
+            raise ValueError("No tokenizer available for text input processing")

         if self.async_tokenizer_pool is None:
             async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
@@ -393,7 +400,8 @@ class CompletionRenderer(BaseRenderer):
             raise ValueError(
                 f"This model's maximum context length is {max_length} tokens. "
                 f"However, your request has {len(token_ids)} input tokens. "
-                "Please reduce the length of the input messages.")
+                "Please reduce the length of the input messages."
+            )

         tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)

         if cache_salt is not None:
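The guard that the reformatted error message above belongs to is a straightforward length check before the engine prompt is built. In isolation:

def check_context_length(token_ids: list[int], max_length: int) -> None:
    """Sketch of the guard the reformatted ValueError above belongs to."""
    if len(token_ids) > max_length:
        raise ValueError(
            f"This model's maximum context length is {max_length} tokens. "
            f"However, your request has {len(token_ids)} input tokens. "
            "Please reduce the length of the input messages."
        )

check_context_length([1, 2, 3], max_length=4)  # OK
# check_context_length([1] * 10, max_length=4)  # would raise ValueError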