[Core] [Bugfix] Add Input Embeddings (#15428)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -2,16 +2,18 @@
 
 import time
 from collections import deque
+from typing import Optional
 from unittest.mock import MagicMock
 
 import pytest  # noqa
+import torch
 
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.interfaces import AllocStatus
 from vllm.core.scheduler import Scheduler, SchedulingBudget
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SequenceGroup
+from vllm.sequence import SequenceGroup, SequenceStatus
 
 from .utils import (append_new_token, append_new_token_seq,
                     append_new_token_seq_group, create_dummy_prompt,
                     get_sequence_groups, schedule_and_update_computed_tokens)
@@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
     ), "A partial prefix of C (4 tokens) should be prefilled, with the "
     "remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
     "then be rounded down to 2 tokens on block size, thus 6 tokens in total."
+
+
+def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
+    """
+    Test that the scheduler does not schedule batches with prompt tokens and
+    prompt embeddings co-mingled.
+    """
+    block_size = 2
+    max_seq_group = 3
+    scheduler = initialize_scheduler(
+        block_size=block_size,
+        num_cpu_blocks=16,
+        num_gpu_blocks=16,
+        max_num_seqs=max_seq_group,
+        max_model_len=100,
+        enable_prefix_caching=True,
+    )
+
+    # the odd indexed inputs should be passed in via embeddings,
+    # evens via token_ids
+    seq_length = 7
+    embedding_size = 5
+    num_seqs = 11
+    seq_tokens: list[list[int]] = []
+    seq_embeds: list[Optional[torch.Tensor]] = []
+    for i in range(num_seqs):
+        if i % 2:
+            seq_tokens.append(list(range(seq_length)))
+            seq_embeds.append(None)
+        else:
+            seq_tokens.append([0] * seq_length)
+            seq_embeds.append(torch.rand(embedding_size))
+
+    seq_and_seq_groups = [
+        create_dummy_prompt(f"{i}",
+                            prompt_tokens=seq_tokens[i],
+                            prompt_embeds=seq_embeds[i],
+                            block_size=block_size)
+        for i in range(len(seq_tokens))
+    ]
+
+    for _, seq_group in seq_and_seq_groups:
+        scheduler.add_seq_group(seq_group)
+
+    while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
+        unfinished_seq_groups = [
+            seq_group for _, seq_group in seq_and_seq_groups
+            if not seq_group.is_finished()
+        ]
+        _, out = schedule_and_update_computed_tokens(scheduler)
+        assert len(out.scheduled_seq_groups) > 0
+        batch_is_prompt_embeds = out.scheduled_seq_groups[
+            0].seq_group.uses_prompt_embeds()
+        expected_scheduled_seq_groups = [
+            seq_group for seq_group in unfinished_seq_groups
+            if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
+        ]
+
+        # We should have as many scheduled groups as possible, without mixing
+        assert len(out.scheduled_seq_groups) == min(
+            max_seq_group, len(expected_scheduled_seq_groups))
+        assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
+                   batch_is_prompt_embeds
+                   for scheduled_seq_group in out.scheduled_seq_groups)
+
+        # Finish the scheduled groups
+        for scheduled_seq_group in out.scheduled_seq_groups:
+            for seq in scheduled_seq_group.seq_group.seqs:
+                seq.status = SequenceStatus.FINISHED_STOPPED
+        scheduler.free_finished_seq_groups()
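The invariant this new test pins down is simple to state: within a single scheduler step, every scheduled sequence group must be of the same kind, either all driven by prompt token IDs or all driven by prompt embeddings, since a single model forward pass typically consumes input_ids or inputs_embeds but not a mixture of the two. Below is a minimal, self-contained sketch of that batching rule, not vLLM's implementation; FakeSeqGroup and schedule_one_batch are hypothetical names invented here, and only the uses_prompt_embeds() predicate mirrors the SequenceGroup API visible in the diff.

# All names below are illustrative only; this is not vLLM code.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeSeqGroup:
    request_id: str
    prompt_tokens: list[int]
    prompt_embeds: Optional[torch.Tensor] = None

    def uses_prompt_embeds(self) -> bool:
        # Mirrors SequenceGroup.uses_prompt_embeds() from the diff above.
        return self.prompt_embeds is not None


def schedule_one_batch(waiting: list[FakeSeqGroup],
                       max_num_seqs: int) -> list[FakeSeqGroup]:
    """Fill a batch only with groups matching the kind (tokens vs.
    embeddings) of the first waiting group, capped at max_num_seqs."""
    if not waiting:
        return []
    batch_kind = waiting[0].uses_prompt_embeds()
    matching = [g for g in waiting if g.uses_prompt_embeds() == batch_kind]
    return matching[:max_num_seqs]


# Usage: alternate embedding and token prompts, as the test does
# (evens get embeddings, odds get token IDs).
waiting = [
    FakeSeqGroup(f"{i}", [0] * 7,
                 prompt_embeds=None if i % 2 else torch.rand(5))
    for i in range(6)
]
batch = schedule_one_batch(waiting, max_num_seqs=3)
assert len({g.uses_prompt_embeds() for g in batch}) == 1  # homogeneous batch
print([g.request_id for g in batch])  # ['0', '2', '4']: embeds-only groups

Compared to this greedy sketch, the real scheduler also has to respect token budgets, block allocation, and prefix caching, which is why the test loops until every group finishes rather than checking a single step.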