Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
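This commit switches the repository's Python formatting from yapf + isort to ruff: ruff format (black-compatible) replaces yapf, with import sorting handled by ruff's isort-compatible rules. The hunks below (one file from the larger change) show the typical mechanical edit: yapf's continuation lines aligned under the opening parenthesis become black-style wrapping with one element per line, a trailing comma, and a dedented closing parenthesis. A minimal before/after sketch, using a hypothetical function rather than code from this diff:

# Hypothetical illustration only -- not taken from this commit.
from typing import Optional

# Before (yapf): continuation lines are aligned under the opening parenthesis.
def truncate_tokens_yapf(token_ids: list[int],
                         max_tokens: Optional[int] = None) -> list[int]:
    return (token_ids if max_tokens is None
            else token_ids[:max_tokens])

# After (ruff format, black-compatible): when a signature does not fit on one
# line, each parameter gets its own line with a trailing comma and the closing
# parenthesis drops back to the statement's indentation.
def truncate_tokens_ruff(
    token_ids: list[int],
    max_tokens: Optional[int] = None,
) -> list[int]:
    return token_ids if max_tokens is None else token_ids[:max_tokens]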
@@ -42,8 +42,7 @@ class RenderConfig:
     needs_detokenization: Optional[bool] = False
     """If True, detokenize IDs back to text for inclusion in outputs."""

-    def verify_truncate_prompt_tokens(
-            self, model_config: ModelConfig) -> Optional[int]:
+    def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> Optional[int]:
         """Validate and normalize `truncate_prompt_tokens` parameter."""
         truncate_prompt_tokens = self.truncate_prompt_tokens
         if truncate_prompt_tokens is None:
@@ -59,7 +58,8 @@ class RenderConfig:
         if max_length is not None and truncate_prompt_tokens > max_length:  # type: ignore[operator]
             raise ValueError(
                 f"{truncate_prompt_tokens=} cannot be greater than "
-                f"{max_length=}. Please select a smaller truncation size.")
+                f"{max_length=}. Please select a smaller truncation size."
+            )

         return truncate_prompt_tokens

@@ -67,13 +67,13 @@ class RenderConfig:
 class BaseRenderer(ABC):
     """
     Base class for unified input processing and rendering.
-
+
     The Renderer serves as a unified input processor that consolidates
     tokenization, chat template formatting, and multimodal input handling
     into a single component.
     It converts high-level API requests (OpenAI-style JSON) into token IDs and
     multimodal features ready for engine consumption.
-
+
     Key responsibilities:
     - Convert text prompts to token sequences with proper special tokens
     - Apply chat templates and format conversations
@@ -112,7 +112,7 @@ class BaseRenderer(ABC):
                 - ``list[int]``: Single pre-tokenized sequence.
                 - ``list[list[int]]``: Batch of pre-tokenized sequences.
             config: Render configuration controlling how prompts are prepared
-                (e.g., tokenization and length handling).
+                (e.g., tokenization and length handling).

         Returns:
             list[EngineTokensPrompt]: Engine-ready token prompts.
@@ -126,8 +126,9 @@ class BaseRenderer(ABC):
     async def render_prompt_and_embeds(
         self,
         *,
-        prompt_or_prompts: Optional[Union[str, list[str], list[int],
-                                          list[list[int]]]] = None,
+        prompt_or_prompts: Optional[
+            Union[str, list[str], list[int], list[list[int]]]
+        ] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
         config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
@@ -144,7 +145,7 @@ class BaseRenderer(ABC):
             prompt_embeds: Base64-encoded bytes (or list thereof) containing a
                 torch-saved tensor to be used as prompt embeddings.
             config: Render configuration controlling how prompts are prepared
-                (e.g., tokenization and length handling).
+                (e.g., tokenization and length handling).

         Returns:
             list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
@@ -195,13 +196,13 @@ class BaseRenderer(ABC):


 class CompletionRenderer(BaseRenderer):
-
     def __init__(
         self,
         model_config: ModelConfig,
         tokenizer: Optional[AnyTokenizer] = None,
-        async_tokenizer_pool: Optional[dict[AnyTokenizer,
-                                            AsyncMicrobatchTokenizer]] = None,
+        async_tokenizer_pool: Optional[
+            dict[AnyTokenizer, AsyncMicrobatchTokenizer]
+        ] = None,
     ):
         super().__init__(model_config, tokenizer)
         self.async_tokenizer_pool = async_tokenizer_pool
@@ -214,28 +215,31 @@ class CompletionRenderer(BaseRenderer):
         config: RenderConfig,
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.
-
+
         Uses async tokenizer pooling for improved performance. See base class
         for detailed parameter documentation.
         """
-        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
-            self.model_config)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config)
         if truncate_prompt_tokens == 0:
             return []

-        tasks = (self._create_prompt(
-            prompt_input,
-            config=config,
-            truncate_prompt_tokens=truncate_prompt_tokens,
-        ) for prompt_input in parse_raw_prompts(prompt_or_prompts))
+        tasks = (
+            self._create_prompt(
+                prompt_input,
+                config=config,
+                truncate_prompt_tokens=truncate_prompt_tokens,
+            )
+            for prompt_input in parse_raw_prompts(prompt_or_prompts)
+        )

         return await asyncio.gather(*tasks)

     async def render_prompt_and_embeds(
         self,
         *,
-        prompt_or_prompts: Optional[Union[str, list[str], list[int],
-                                          list[list[int]]]] = None,
+        prompt_or_prompts: Optional[
+            Union[str, list[str], list[int], list[list[int]]]
+        ] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
         config: RenderConfig,
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
@@ -243,8 +247,7 @@ class CompletionRenderer(BaseRenderer):
         Render text/token prompts and/or precomputed embedding prompts. At
         least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
         """
-        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(
-            self.model_config)
+        truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config)
         if truncate_prompt_tokens == 0:
             return []

@@ -252,8 +255,10 @@ class CompletionRenderer(BaseRenderer):

         if prompt_embeds is not None:
             rendered.extend(
-                self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
-                                        config.cache_salt))
+                self.load_prompt_embeds(
+                    prompt_embeds, truncate_prompt_tokens, config.cache_salt
+                )
+            )
         if prompt_or_prompts is None or prompt_or_prompts == "":
             return rendered

@@ -266,8 +271,8 @@ class CompletionRenderer(BaseRenderer):
         return rendered

     def _maybe_apply_truncation(
-            self, token_ids: list[int],
-            truncate_prompt_tokens: Optional[int]) -> list[int]:
+        self, token_ids: list[int], truncate_prompt_tokens: Optional[int]
+    ) -> list[int]:
         """Apply truncation to token sequence."""
         if truncate_prompt_tokens is None:
             return token_ids
@@ -319,24 +324,26 @@ class CompletionRenderer(BaseRenderer):
         async_tokenizer = self._get_async_tokenizer()

         # Handle encoder-specific preprocessing
-        if (self.model_config.encoder_config is not None
-                and self.model_config.encoder_config.get(
-                    "do_lower_case", False)):
+        if (
+            self.model_config.encoder_config is not None
+            and self.model_config.encoder_config.get("do_lower_case", False)
+        ):
             text = text.lower()

         # Tokenize texts
         if truncate_prompt_tokens is None:
-            encoded = await async_tokenizer(
-                text, add_special_tokens=add_special_tokens)
+            encoded = await async_tokenizer(text, add_special_tokens=add_special_tokens)
         else:
             encoded = await async_tokenizer(
                 text,
                 add_special_tokens=add_special_tokens,
                 truncation=True,
-                max_length=truncate_prompt_tokens)
+                max_length=truncate_prompt_tokens,
+            )

-        return self._create_tokens_prompt(encoded.input_ids, max_length,
-                                          cache_salt, text)
+        return self._create_tokens_prompt(
+            encoded.input_ids, max_length, cache_salt, text
+        )

     async def _create_prompt_from_token_ids(
         self,
@@ -347,18 +354,19 @@ class CompletionRenderer(BaseRenderer):
         needs_detokenization: Optional[bool] = False,
     ) -> EngineTokensPrompt:
         """Optionally detokenize token IDs and build a tokens prompt."""
-        token_ids = self._maybe_apply_truncation(token_ids,
-                                                 truncate_prompt_tokens)
+        token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens)

         prompt = None
         if needs_detokenization:
             async_tokenizer = self._get_async_tokenizer()
             prompt = await async_tokenizer.decode(token_ids)

-        return self._create_tokens_prompt(token_ids=token_ids,
-                                          max_length=max_length,
-                                          cache_salt=cache_salt,
-                                          prompt=prompt)
+        return self._create_tokens_prompt(
+            token_ids=token_ids,
+            max_length=max_length,
+            cache_salt=cache_salt,
+            prompt=prompt,
+        )

     def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
         """Get or create async tokenizer using shared pool."""
@@ -368,8 +376,7 @@ class CompletionRenderer(BaseRenderer):

         tokenizer = self.tokenizer
         if self.tokenizer is None:
-            raise ValueError(
-                "No tokenizer available for text input processing")
+            raise ValueError("No tokenizer available for text input processing")

         if self.async_tokenizer_pool is None:
             async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
@@ -393,7 +400,8 @@ class CompletionRenderer(BaseRenderer):
             raise ValueError(
                 f"This model's maximum context length is {max_length} tokens. "
                 f"However, your request has {len(token_ids)} input tokens. "
-                "Please reduce the length of the input messages.")
+                "Please reduce the length of the input messages."
+            )

         tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
         if cache_salt is not None: