Support embedding models in V1 (#16188)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
Maximilien de Bayser
2025-06-19 01:36:33 -03:00
committed by GitHub
parent 4959915089
commit 799397ee4f
56 changed files with 889 additions and 281 deletions

View File

@@ -136,8 +136,8 @@ class Processor:
Should raise ValueError if unsupported for API Server.
"""
if not isinstance(params, SamplingParams):
raise ValueError("V1 does not yet support Pooling models.")
if isinstance(params, PoolingParams):
return
self._validate_logprobs(params)
self._validate_sampling_params(params, lora_request)
@@ -263,18 +263,22 @@ class Processor:
if encoder_inputs is not None:
raise NotImplementedError
assert isinstance(params, SamplingParams)
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
sampling_params = None
pooling_params = None
if isinstance(params, SamplingParams):
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
else:
pooling_params = params.clone()
# Multimodal related.
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
@@ -331,6 +335,7 @@ class Processor:
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params,
pooling_params=pooling_params,
eos_token_id=eos_token_id,
arrival_time=arrival_time,
lora_request=lora_request,