[Refactor] Pass full VllmConfig to Renderer (#34485)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
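This refactor changes renderer construction to take the full VllmConfig
instead of a bare ModelConfig: BaseRenderer.__init__ now accepts a
VllmConfig and exposes self.model_config, renderer_from_config and
InputPreprocessor are built from the VllmConfig (the separate
observability_config argument to InputPreprocessor is dropped), and the
renderer subclasses and their tests are updated accordingly.

A minimal usage sketch of the new construction path, mirroring the updated
preprocessing test below (the model name is a placeholder):

    from vllm.config import ModelConfig, VllmConfig
    from vllm.inputs.preprocess import InputPreprocessor

    # Wrap the ModelConfig in a full VllmConfig before building the preprocessor.
    model_config = ModelConfig(model="<model-id>")
    vllm_config = VllmConfig(model_config=model_config)

    # InputPreprocessor (and the renderer it creates internally) now receives
    # the whole VllmConfig rather than just the ModelConfig.
    input_preprocessor = InputPreprocessor(vllm_config)
    tokenizer = input_preprocessor.get_tokenizer()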
@@ -59,11 +59,16 @@ class MockModelConfig:
         return self.diff_sampling_param or {}


+@dataclass
+class MockVllmConfig:
+    model_config: MockModelConfig
+
+
 def _build_renderer(model_config: MockModelConfig):
     _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

     return HfRenderer(
-        model_config,
+        MockVllmConfig(model_config),
         tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
     )

@@ -58,6 +58,11 @@ class MockModelConfig:
         return self.diff_sampling_param or {}


+@dataclass
+class MockVllmConfig:
+    model_config: MockModelConfig
+
+
 def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
     models = OpenAIServingModels(
         engine_client=engine,
@@ -74,7 +79,7 @@ def _build_renderer(model_config: MockModelConfig):
     _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

     return HfRenderer(
-        model_config,
+        MockVllmConfig(model_config),
         tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
     )

@@ -57,6 +57,11 @@ class MockModelConfig:
         return self.diff_sampling_param or {}


+@dataclass
+class MockVllmConfig:
+    model_config: MockModelConfig
+
+
 class MockLoRAResolver(LoRAResolver):
     async def resolve_lora(
         self, base_model_name: str, lora_name: str
@@ -91,7 +96,7 @@ def _build_renderer(model_config: MockModelConfig):
     _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

     return HfRenderer(
-        model_config,
+        MockVllmConfig(model_config),
         tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
     )

@@ -534,11 +534,16 @@ class MockModelConfig:
         return self.diff_sampling_param or {}


+@dataclass
+class MockVllmConfig:
+    model_config: MockModelConfig
+
+
 def _build_renderer(model_config: MockModelConfig):
     _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

     return HfRenderer(
-        model_config,
+        MockVllmConfig(model_config),
         tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
     )

@@ -749,7 +754,10 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
     mock_engine.io_processor = MagicMock()

     mock_tokenizer = MagicMock(spec=MistralTokenizer)
-    mock_renderer = MistralRenderer(mock_engine.model_config, tokenizer_kwargs={})
+    mock_renderer = MistralRenderer(
+        MockVllmConfig(mock_engine.model_config),
+        tokenizer_kwargs={},
+    )
     mock_renderer._tokenizer = mock_tokenizer
     # Force the Mistral chat template renderer to return token IDs.
     # Choose a prompt length that is < max_model_len, but large enough that
@@ -788,7 +796,10 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
     mock_engine.io_processor = MagicMock()

     mock_tokenizer = MagicMock(spec=MistralTokenizer)
-    mock_renderer = MistralRenderer(mock_engine.model_config, tokenizer_kwargs={})
+    mock_renderer = MistralRenderer(
+        MockVllmConfig(mock_engine.model_config),
+        tokenizer_kwargs={},
+    )
     mock_renderer._tokenizer = mock_tokenizer
     # prompt_token_ids length == max_model_len should be rejected for
     # completion-like requests (ChatCompletionRequest).
@@ -40,6 +40,11 @@ class MockModelConfig:
     is_encoder_decoder: bool = False


+@dataclass
+class MockVllmConfig:
+    model_config: MockModelConfig
+
+
 @dataclass
 class DummyTokenizer:
     truncation_side: str = "left"
@@ -72,7 +77,7 @@ def _build_renderer(
     _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

     renderer = HfRenderer(
-        model_config,
+        MockVllmConfig(model_config),
         tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
     )

@@ -104,14 +109,14 @@ class TestValidatePrompt:
         renderer = _build_renderer(MockModelConfig())

         with pytest.raises(ValueError, match="at least one prompt"):
-            renderer.render_prompts(_preprocess_prompt(renderer.config, []))
+            renderer.render_prompts(_preprocess_prompt(renderer.model_config, []))

     def test_invalid_type(self):
         renderer = _build_renderer(MockModelConfig())

         with pytest.raises(TypeError, match="should be a list of integers"):
             renderer.render_prompts(
-                _preprocess_prompt(renderer.config, [[1, 2], ["foo", "bar"]])  # type: ignore[arg-type]
+                _preprocess_prompt(renderer.model_config, [[1, 2], ["foo", "bar"]])  # type: ignore[arg-type]
             )


@@ -120,7 +125,9 @@ class TestRenderPrompt:
         renderer = _build_renderer(MockModelConfig())

         tokens = [101, 7592, 2088]
-        prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
+        prompts = renderer.render_prompts(
+            _preprocess_prompt(renderer.model_config, tokens)
+        )
         results = renderer.tokenize_prompts(
             prompts,
             TokenizeParams(max_total_tokens=100),
@@ -134,7 +141,7 @@ class TestRenderPrompt:

         token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, token_lists)
+            _preprocess_prompt(renderer.model_config, token_lists)
         )
         results = renderer.tokenize_prompts(
             prompts,
@@ -151,7 +158,7 @@ class TestRenderPrompt:

         text_input = "x" * 10
         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, text_input)
+            _preprocess_prompt(renderer.model_config, text_input)
         )
         results = renderer.tokenize_prompts(
             prompts,
@@ -166,7 +173,7 @@ class TestRenderPrompt:

         text_list_input = ["x" * 10, "x" * 12, "x" * 14]
         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, text_list_input)
+            _preprocess_prompt(renderer.model_config, text_list_input)
         )
         results = renderer.tokenize_prompts(
             prompts,
@@ -181,7 +188,7 @@ class TestRenderPrompt:
         renderer = _build_renderer(MockModelConfig())

         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, "x" * 200)
+            _preprocess_prompt(renderer.model_config, "x" * 200)
         )
         results = renderer.tokenize_prompts(
             prompts,
@@ -195,7 +202,7 @@ class TestRenderPrompt:
         renderer = _build_renderer(MockModelConfig())

         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, "x" * 200)
+            _preprocess_prompt(renderer.model_config, "x" * 200)
         )
         results = renderer.tokenize_prompts(
             prompts,
@@ -209,7 +216,7 @@ class TestRenderPrompt:
         renderer = _build_renderer(MockModelConfig())

         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, "x" * 200)
+            _preprocess_prompt(renderer.model_config, "x" * 200)
         )
         results = renderer.tokenize_prompts(
             prompts,
@@ -224,7 +231,7 @@ class TestRenderPrompt:

         long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, long_tokens)
+            _preprocess_prompt(renderer.model_config, long_tokens)
         )
         results = renderer.tokenize_prompts(
             prompts,
@@ -240,7 +247,7 @@ class TestRenderPrompt:

         long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, long_tokens)
+            _preprocess_prompt(renderer.model_config, long_tokens)
         )
         results = renderer.tokenize_prompts(
             prompts,
@@ -257,7 +264,7 @@ class TestRenderPrompt:
         # Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
         long_tokens = "x" * 150
         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, long_tokens)
+            _preprocess_prompt(renderer.model_config, long_tokens)
         )

         with pytest.raises(
@@ -278,7 +285,7 @@ class TestRenderPrompt:
         # Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
         long_tokens = "x" * 150
         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, long_tokens)
+            _preprocess_prompt(renderer.model_config, long_tokens)
         )

         with pytest.raises(
@@ -299,7 +306,7 @@ class TestRenderPrompt:

         long_tokens = list(range(150))  # Exceeds max_total_tokens=100
         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, long_tokens)
+            _preprocess_prompt(renderer.model_config, long_tokens)
         )

         with pytest.raises(
@@ -315,7 +322,7 @@ class TestRenderPrompt:
         renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))

         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, "Hello world")
+            _preprocess_prompt(renderer.model_config, "Hello world")
         )

         with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
@@ -328,7 +335,9 @@ class TestRenderPrompt:
         renderer = _build_renderer(MockModelConfig())

         tokens = [1, 2, 3, 4]
-        prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
+        prompts = renderer.render_prompts(
+            _preprocess_prompt(renderer.model_config, tokens)
+        )
         results = renderer.tokenize_prompts(
             prompts,
             TokenizeParams(
@@ -358,7 +367,7 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(tensor_input)

         prompts = renderer.render_prompts(
-            _preprocess_prompt(renderer.config, embed_bytes)
+            _preprocess_prompt(renderer.model_config, embed_bytes)
         )
         results = renderer.tokenize_prompts(
             prompts,
@@ -379,7 +388,7 @@ class TestRenderEmbedPrompt:

         prompts = renderer.render_prompts(
             _preprocess_prompt(
-                renderer.config,
+                renderer.model_config,
                 [self._create_test_embed_bytes(t) for t in tensor_inputs],
             )
         )
@@ -400,7 +409,7 @@ class TestRenderEmbedPrompt:

         prompts = renderer.render_prompts(
             _preprocess_prompt(
-                renderer.config, self._create_test_embed_bytes(tensor_input)
+                renderer.model_config, self._create_test_embed_bytes(tensor_input)
             )
         )
         results = renderer.tokenize_prompts(
@@ -427,7 +436,7 @@ class TestRenderEmbedPrompt:

         prompts = renderer.render_prompts(
             _preprocess_prompt(
-                renderer.config, self._create_test_embed_bytes(tensor_input)
+                renderer.model_config, self._create_test_embed_bytes(tensor_input)
             )
         )
         results = renderer.tokenize_prompts(
@@ -446,7 +455,7 @@ class TestRenderEmbedPrompt:

         prompts = renderer.render_prompts(
             _preprocess_prompt(
-                renderer.config, self._create_test_embed_bytes(tensor_input)
+                renderer.model_config, self._create_test_embed_bytes(tensor_input)
             )
         )
         results = renderer.tokenize_prompts(
@@ -466,7 +475,7 @@ class TestRenderEmbedPrompt:

         prompts = renderer.render_prompts(
             _preprocess_prompt(
-                renderer.config,
+                renderer.model_config,
                 [text_input, self._create_test_embed_bytes(tensor_input)],
             )
         )
@@ -38,6 +38,11 @@ class MockModelConfig:
     is_encoder_decoder: bool = False


+@dataclass
+class MockVllmConfig:
+    model_config: MockModelConfig
+
+
 @pytest.mark.asyncio
 async def test_async_mistral_tokenizer_does_not_block_event_loop():
     expected_tokens = [1, 2, 3]
@@ -50,7 +55,10 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
     mock_model_config = MockModelConfig(skip_tokenizer_init=True)
     mock_tokenizer = Mock(spec=MistralTokenizer)
     mock_tokenizer.apply_chat_template = mocked_apply_chat_template
-    mock_renderer = MistralRenderer(mock_model_config, tokenizer_kwargs={})
+    mock_renderer = MistralRenderer(
+        MockVllmConfig(mock_model_config),
+        tokenizer_kwargs={},
+    )
     mock_renderer._tokenizer = mock_tokenizer

     task = mock_renderer.render_messages_async([], ChatParams())
@@ -3,7 +3,7 @@

 import pytest

-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.inputs.preprocess import InputPreprocessor

 pytestmark = pytest.mark.cpu_test
@@ -20,7 +20,8 @@ pytestmark = pytest.mark.cpu_test
 )
 def test_preprocessor_always_mm_code_path(model_id, prompt):
     model_config = ModelConfig(model=model_id)
-    input_preprocessor = InputPreprocessor(model_config)
+    vllm_config = VllmConfig(model_config=model_config)
+    input_preprocessor = InputPreprocessor(vllm_config)

     # HF processor adds sep token
     tokenizer = input_preprocessor.get_tokenizer()
@@ -6,7 +6,7 @@ from typing import Any, overload

 from typing_extensions import assert_never

-from vllm.config import ModelConfig, ObservabilityConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
@@ -54,17 +54,16 @@ logger = init_logger(__name__)
 class InputPreprocessor:
     def __init__(
         self,
-        model_config: ModelConfig,
-        observability_config: ObservabilityConfig | None = None,
+        vllm_config: VllmConfig,
         renderer: BaseRenderer | None = None,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         mm_processor_cache: BaseMultiModalProcessorCache | None = None,
     ) -> None:
         super().__init__()

-        self.model_config = model_config
-        self.observability_config = observability_config
-        self.renderer = renderer or renderer_from_config(model_config)
+        self.model_config = vllm_config.model_config
+        self.observability_config = vllm_config.observability_config
+        self.renderer = renderer or renderer_from_config(vllm_config)
         self.mm_registry = mm_registry
         self.mm_processor_cache = mm_processor_cache

@@ -21,7 +21,7 @@ from .inputs.preprocess import extract_target_prompt
 from .params import ChatParams, TokenizeParams

 if TYPE_CHECKING:
-    from vllm.config import ModelConfig
+    from vllm.config import VllmConfig
     from vllm.entrypoints.chat_utils import (
         ChatCompletionMessageParam,
         ConversationMessage,
@@ -35,15 +35,15 @@ class BaseRenderer(ABC):
     @abstractmethod
     def from_config(
         cls,
-        config: "ModelConfig",
+        config: "VllmConfig",
         tokenizer_kwargs: dict[str, Any],
     ) -> "BaseRenderer":
         raise NotImplementedError

-    def __init__(self, config: "ModelConfig") -> None:
+    def __init__(self, config: "VllmConfig") -> None:
         super().__init__()

-        self.config = config
+        self.model_config = config.model_config

         # Lazy initialization since offline LLM doesn't use async
         self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
@@ -90,7 +90,7 @@ class BaseRenderer(ABC):
         prompt: DictPrompt | bytes,
     ) -> DictPrompt:
         if isinstance(prompt, bytes):
-            embeds = safe_load_prompt_embeds(self.config, prompt)
+            embeds = safe_load_prompt_embeds(self.model_config, prompt)
             prompt = EmbedsPrompt(prompt_embeds=embeds)

         return prompt
@@ -310,7 +310,7 @@ class BaseRenderer(ABC):
             return

         for prompt in prompts:
-            target_prompt = extract_target_prompt(self.config, prompt)
+            target_prompt = extract_target_prompt(self.model_config, prompt)
             target_prompt.update(prompt_extras)  # type: ignore[arg-type]

     # Top-level methods
@@ -325,7 +325,7 @@ class BaseRenderer(ABC):

         # NOTE: Some MM models have non-default `add_special_tokens`
         # so we handle tokenization in multi-modal processor
-        if self.config.is_multimodal_model:
+        if self.model_config.is_multimodal_model:
             self._apply_prompt_extras(dict_prompts, prompt_extras)
             return dict_prompts

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Any

-from vllm.config import ModelConfig
+from vllm.config import VllmConfig
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ConversationMessage,
@@ -26,19 +26,20 @@ class DeepseekV32Renderer(BaseRenderer):
     @classmethod
     def from_config(
         cls,
-        config: ModelConfig,
+        config: VllmConfig,
         tokenizer_kwargs: dict[str, Any],
     ) -> "BaseRenderer":
         return cls(config, tokenizer_kwargs)

     def __init__(
         self,
-        config: ModelConfig,
+        config: VllmConfig,
         tokenizer_kwargs: dict[str, Any],
     ) -> None:
         super().__init__(config)

-        if config.skip_tokenizer_init:
+        model_config = self.model_config
+        if model_config.skip_tokenizer_init:
             tokenizer = None
         else:
             tokenizer = cached_get_tokenizer(
@@ -67,7 +68,7 @@ class DeepseekV32Renderer(BaseRenderer):
         tokenizer = self.get_tokenizer()
         conversation, mm_data, mm_uuids = parse_chat_messages(
             messages,
-            self.config,
+            self.model_config,
             content_format="string",
         )

@@ -93,7 +94,7 @@ class DeepseekV32Renderer(BaseRenderer):
         tokenizer = self.get_tokenizer()
         conversation, mm_data, mm_uuids = await parse_chat_messages_async(
             messages,
-            self.config,
+            self.model_config,
             content_format="string",
         )

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Any

-from vllm.config import ModelConfig
+from vllm.config import VllmConfig
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ConversationMessage,
@@ -25,19 +25,20 @@ class Grok2Renderer(BaseRenderer):
     @classmethod
     def from_config(
         cls,
-        config: ModelConfig,
+        config: VllmConfig,
         tokenizer_kwargs: dict[str, Any],
     ) -> "BaseRenderer":
         return cls(config, tokenizer_kwargs)

     def __init__(
         self,
-        config: ModelConfig,
+        config: VllmConfig,
         tokenizer_kwargs: dict[str, Any],
     ) -> None:
         super().__init__(config)

-        if config.skip_tokenizer_init:
+        model_config = self.model_config
+        if model_config.skip_tokenizer_init:
             tokenizer = None
         else:
             tokenizer = cached_get_tokenizer(
@@ -66,7 +67,7 @@ class Grok2Renderer(BaseRenderer):
         tokenizer = self.get_tokenizer()
         conversation, mm_data, mm_uuids = parse_chat_messages(
             messages,
-            self.config,
+            self.model_config,
             content_format="string",
         )

@@ -92,7 +93,7 @@ class Grok2Renderer(BaseRenderer):
         tokenizer = self.get_tokenizer()
         conversation, mm_data, mm_uuids = await parse_chat_messages_async(
             messages,
-            self.config,
+            self.model_config,
             content_format="string",
         )

@@ -14,7 +14,7 @@ import jinja2.nodes
 import jinja2.parser
 import jinja2.sandbox

-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ChatTemplateContentFormat,
@@ -589,23 +589,24 @@ class HfRenderer(BaseRenderer):
     @classmethod
     def from_config(
         cls,
-        config: ModelConfig,
+        config: VllmConfig,
         tokenizer_kwargs: dict[str, Any],
     ) -> "BaseRenderer":
         return cls(config, tokenizer_kwargs)

     def __init__(
         self,
-        config: ModelConfig,
+        config: VllmConfig,
         tokenizer_kwargs: dict[str, Any],
     ) -> None:
         super().__init__(config)

+        model_config = self.model_config
         self.use_unified_vision_chunk = getattr(
-            config.hf_config, "use_unified_vision_chunk", False
+            model_config.hf_config, "use_unified_vision_chunk", False
         )

-        if config.skip_tokenizer_init:
+        if model_config.skip_tokenizer_init:
             tokenizer = None
         else:
             tokenizer = cast(
@@ -634,7 +635,7 @@ class HfRenderer(BaseRenderer):
         messages: list[ChatCompletionMessageParam],
         params: ChatParams,
     ) -> tuple[list[ConversationMessage], DictPrompt]:
-        model_config = self.config
+        model_config = self.model_config
         tokenizer = self.get_tokenizer()

         conversation, mm_data, mm_uuids = parse_chat_messages(
@@ -688,7 +689,7 @@ class HfRenderer(BaseRenderer):
         messages: list[ChatCompletionMessageParam],
         params: ChatParams,
     ) -> tuple[list[ConversationMessage], DictPrompt]:
-        model_config = self.config
+        model_config = self.model_config
         tokenizer = self.get_tokenizer()

         conversation, mm_data, mm_uuids = await parse_chat_messages_async(
@@ -3,7 +3,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any

-from vllm.config import ModelConfig
+from vllm.config import VllmConfig
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ConversationMessage,
@@ -54,19 +54,20 @@ class MistralRenderer(BaseRenderer):
     @classmethod
     def from_config(
         cls,
-        config: ModelConfig,
+        config: VllmConfig,
         tokenizer_kwargs: dict[str, Any],
     ) -> "BaseRenderer":
         return cls(config, tokenizer_kwargs)

     def __init__(
         self,
-        config: ModelConfig,
+        config: VllmConfig,
         tokenizer_kwargs: dict[str, Any],
     ) -> None:
         super().__init__(config)

-        if config.skip_tokenizer_init:
+        model_config = self.model_config
+        if model_config.skip_tokenizer_init:
             tokenizer = None
         else:
             tokenizer = cached_get_tokenizer(
@@ -100,7 +101,7 @@ class MistralRenderer(BaseRenderer):
         tokenizer = self.get_tokenizer()
         conversation, mm_data, mm_uuids = parse_chat_messages(
             messages,
-            self.config,
+            self.model_config,
             content_format="string",
         )

@@ -126,7 +127,7 @@ class MistralRenderer(BaseRenderer):
         tokenizer = self.get_tokenizer()
         conversation, mm_data, mm_uuids = await parse_chat_messages_async(
             messages,
-            self.config,
+            self.model_config,
             content_format="string",
         )

@@ -10,7 +10,7 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
 from .base import BaseRenderer

 if TYPE_CHECKING:
-    from vllm.config import ModelConfig
+    from vllm.config import VllmConfig

 logger = init_logger(__name__)

@@ -55,7 +55,7 @@ class RendererRegistry:
     def load_renderer(
         self,
         renderer_mode: str,
-        config: "ModelConfig",
+        config: "VllmConfig",
         tokenizer_kwargs: dict[str, Any],
     ) -> BaseRenderer:
         renderer_cls = self.load_renderer_cls(renderer_mode)
@@ -71,12 +71,16 @@ RENDERER_REGISTRY = RendererRegistry(
 """The global `RendererRegistry` instance."""


-def renderer_from_config(config: "ModelConfig", **kwargs):
+def renderer_from_config(config: "VllmConfig", **kwargs):
+    model_config = config.model_config
     tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
-        config, **kwargs
+        model_config, **kwargs
     )

-    if config.tokenizer_mode == "auto" and config.model_impl == "terratorch":
+    if (
+        model_config.tokenizer_mode == "auto"
+        and model_config.model_impl == "terratorch"
+    ):
         renderer_mode = "terratorch"
     else:
         renderer_mode = tokenizer_mode
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Any

-from vllm.config import ModelConfig
+from vllm.config import VllmConfig
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ConversationMessage,
@@ -24,15 +24,16 @@ class TerratorchRenderer(BaseRenderer):
     @classmethod
     def from_config(
         cls,
-        config: "ModelConfig",
+        config: VllmConfig,
         tokenizer_kwargs: dict[str, Any],
     ) -> "BaseRenderer":
         return cls(config)

-    def __init__(self, config: ModelConfig) -> None:
+    def __init__(self, config: VllmConfig) -> None:
         super().__init__(config)

-        if not config.skip_tokenizer_init:
+        model_config = self.model_config
+        if not model_config.skip_tokenizer_init:
             raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")

     @property
@@ -47,7 +48,7 @@ class TerratorchRenderer(BaseRenderer):
         messages: list[ChatCompletionMessageParam],
         params: ChatParams,
     ) -> tuple[list[ConversationMessage], DictPrompt]:
-        model_config = self.config
+        model_config = self.model_config

         conversation, mm_data, mm_uuids = parse_chat_messages(
             messages,
@@ -68,7 +69,7 @@ class TerratorchRenderer(BaseRenderer):
         messages: list[ChatCompletionMessageParam],
         params: ChatParams,
     ) -> tuple[list[ConversationMessage], DictPrompt]:
-        model_config = self.config
+        model_config = self.model_config

         conversation, mm_data, mm_uuids = await parse_chat_messages_async(
             messages,
@@ -132,7 +132,7 @@ class AsyncLLM(EngineClient):
                 "enabling logging without default stat loggers."
             )

-        self.renderer = renderer = renderer_from_config(self.model_config)
+        self.renderer = renderer = renderer_from_config(self.vllm_config)
         self.io_processor = get_io_processor(
             self.vllm_config,
             self.model_config.io_processor_plugin,
@@ -59,7 +59,7 @@ class InputProcessor:

         self.generation_config_fields = model_config.try_get_generation_config()

-        self.renderer = renderer or renderer_from_config(model_config)
+        self.renderer = renderer or renderer_from_config(vllm_config)
         self.mm_registry = mm_registry
         self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)

@@ -75,8 +75,7 @@ class InputProcessor:
             mm_budget.reset_cache()  # Not used anymore

         self.input_preprocessor = InputPreprocessor(
-            model_config,
-            self.observability_config,
+            vllm_config,
             renderer=renderer,
             mm_registry=mm_registry,
             mm_processor_cache=self.mm_processor_cache,
@@ -90,7 +90,7 @@ class LLMEngine:
         self.dp_group = None
         self.should_execute_dummy_batch = False

-        self.renderer = renderer = renderer_from_config(self.model_config)
+        self.renderer = renderer = renderer_from_config(self.vllm_config)
         self.io_processor = get_io_processor(
             self.vllm_config,
             self.model_config.io_processor_plugin,