[Model Runner V2] Minor refactoring for EncoderRunner (#35628)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-03-01 00:15:39 -08:00
committed by GitHub
parent 87d319c52f
commit da543d1abe
2 changed files with 13 additions and 13 deletions

View File

@@ -13,12 +13,14 @@ from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
class EncoderRunner:
def __init__(
self,
model: SupportsMultiModal,
max_num_tokens: int,
hidden_size: int,
encoder_cache: EncoderCache,
dtype: torch.dtype,
device: torch.device,
):
self.model = model
self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size
self.encoder_cache = encoder_cache
@@ -48,25 +50,17 @@ class EncoderRunner:
@torch.inference_mode()
def execute_mm_encoder(
self,
model: SupportsMultiModal,
mm_hashes: list[str],
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
) -> list[torch.Tensor]:
if not mm_hashes:
return []
encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs, device=self.device, pin_memory=False
):
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs, expected_num_items=num_items
)
encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
return encoder_outputs
def gather_mm_embeddings(
@@ -146,12 +140,11 @@ class EncoderRunner:
@torch.inference_mode()
def get_inputs_embeds(
self,
model: SupportsMultiModal,
input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor],
is_mm_embed: torch.Tensor,
) -> torch.Tensor:
x = model.embed_input_ids(
x = self.model.embed_input_ids(
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
)
# Copy to the pre-allocated buffer for CUDA graphs.

View File

@@ -41,7 +41,9 @@ class DefaultModelState(ModelState):
if self.supports_mm_inputs:
assert encoder_cache is not None
self.encoder_cache = encoder_cache
self.encoder_runner = EncoderRunner(
model=self.model,
max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size,
encoder_cache=encoder_cache,
@@ -82,7 +84,12 @@ class DefaultModelState(ModelState):
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs
)
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
if mm_kwargs:
# Execute the multimodal encoder.
encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
@@ -92,7 +99,7 @@ class DefaultModelState(ModelState):
req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
)
inputs_embeds = self.encoder_runner.get_inputs_embeds(
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
input_batch.input_ids, mm_embeds, is_mm_embed
)
return inputs_embeds[: input_batch.num_tokens_after_padding]