[Core] Whisper Enable Encoder Batching (#29421)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-12-11 22:06:51 +01:00
committed by GitHub
parent 90d6cf921f
commit 0efd9f867c
5 changed files with 87 additions and 25 deletions

View File

@@ -341,3 +341,56 @@ def compute_mm_encoder_budget(
)
return encoder_compute_budget, encoder_cache_size
# NOTE (NickLucche): Temporary implementation for encoder-decoder models that only
# use the manager for scheduling purposes. Encoder-decoder models will eventually
# utilize the cache and this class will fold into EncoderCacheManager, as
# differences with MM models shrink.
class EncoderDecoderCacheManager(EncoderCacheManager):
def __init__(self, cache_size: int):
self.cache_size = cache_size
self.num_free_slots = cache_size
self.freed: list[str] = []
def check_and_update_cache(self, request: Request, input_id: int) -> bool:
return False
def can_allocate(
self,
request: Request,
input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int,
) -> bool:
num_tokens = request.get_num_encoder_tokens(input_id)
# Not enough compute budget
if num_tokens > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
# Enough free slots
return num_tokens <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
self.num_free_slots -= num_encoder_tokens
mm_hash = request.mm_features[input_id].identifier
self.freed.append(mm_hash)
def free(self, request: Request) -> None:
for input_id in range(len(request.mm_features)):
self.free_encoder_input(request, input_id)
def get_cached_input_ids(self, request: Request) -> set[int]:
return set(range(len(request.mm_features)))
def get_freed_mm_hashes(self) -> list[str]:
freed = self.freed
self.freed = []
return freed
def free_encoder_input(self, request: Request, input_id: int) -> None:
num_tokens = request.get_num_encoder_tokens(input_id)
self.num_free_slots += num_tokens