[SpecDec] Remove Batch Expansion (2/3) (#9298)

This commit is contained in:
Lily Liu
2024-10-11 22:13:37 -07:00
committed by GitHub
parent ec10cb8511
commit 89feb4c84d
8 changed files with 122 additions and 70 deletions

View File

@@ -1,3 +1,6 @@
import random
from typing import List
import pytest
import torch
@@ -10,31 +13,45 @@ from vllm.worker.worker import Worker
from .utils import create_batch, create_worker
def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
def create_proposal(propose_lens: List[int], vocab_size: int,
device: str) -> SpeculativeProposals:
proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
batch_size = len(propose_lens)
max_propose_len = max(propose_lens)
proposal_probs = torch.rand((batch_size, max_propose_len, vocab_size),
device=device)
proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
proposal_token_ids = torch.full((batch_size, max_propose_len),
fill_value=-1,
device=device)
for i in range(batch_size):
proposal_token_ids[i][:propose_lens[i]] = torch.argmax(
proposal_probs[i][:propose_lens[i]], dim=-1)
propose_lens = torch.tensor(propose_lens, device=device)
return SpeculativeProposals(proposal_token_ids, proposal_probs,
proposal_lens)
propose_lens)
def assert_score_equal(score1: SpeculativeScores,
score2: SpeculativeScores) -> None:
assert torch.allclose(score1.probs, score2.probs)
assert torch.allclose(score1.logprobs, score2.logprobs)
assert torch.equal(score1.token_ids, score2.token_ids)
assert torch.equal(
score1.token_ids,
score2.token_ids), f"{score1.token_ids}, {score2.token_ids}"
@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
@pytest.mark.parametrize('mixed_propose_len', [True])
@pytest.mark.parametrize('device', ['cuda'])
def test_scoroer(model_name: str, batch_size: int, propose_len: int,
device: str) -> None:
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
mixed_propose_len: bool, device: str) -> None:
"""
Compare the batch expansion scorer and mqa scorer return the same score
Check that the batch expansion scorer and the MQA scorer return the same
scores. The comparison is meant to hold both when all queries share the same
proposal length and when proposal lengths differ across queries.
"""
seed = 0
block_size = 32
@@ -46,13 +63,22 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int,
should_modify_greedy_probs_inplace = True
vocab_size = scorer_worker.vocab_size
proposals = create_proposal(batch_size, propose_len, vocab_size, device)
if not mixed_propose_len:
propose_lens = [max_propose_len] * batch_size
else:
non_zero_cnt = random.randint(0, batch_size)
propose_lens = [max_propose_len
] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
random.shuffle(propose_lens)
proposals = create_proposal(propose_lens, vocab_size, device)
seq_group_metadatalist, _, _ = create_batch(batch_size,
propose_len,
max_propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=propose_len)
num_lookahead_slots=max_propose_len)
batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
vocab_size)