[Core] Factor out common code in SequenceData and Sequence (#8675)

Cyrus Leung
2024-09-21 10:30:39 +08:00
committed by GitHub
parent d4bf085ad0
commit 0455c46ed4
8 changed files with 64 additions and 97 deletions


@@ -1,13 +1,11 @@
 import itertools
-from array import array
 from typing import List
 
 import pytest
 import torch
 
 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_cpu, make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import _get_graph_batch_size
@@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
         encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
@@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
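
Note on the change above: the tests no longer build a SequenceData by hand from array(VLLM_TOKEN_ID_ARRAY_TYPE, ...); they call a SequenceData.from_seqs(...) factory that hides the array conversion. The following is a minimal, self-contained sketch of what such a factory could look like based only on the call sites in this diff; the class body, field names, and the optional output_token_ids parameter are illustrative assumptions, not the actual vllm.sequence implementation.

# Sketch of a SequenceData.from_seqs-style factory (assumptions, not vLLM source).
from array import array
from typing import Iterable, Optional

VLLM_TOKEN_ID_ARRAY_TYPE = "l"  # token ids kept in a compact typed array


class SequenceData:

    def __init__(self,
                 prompt_token_ids: array,
                 output_token_ids: Optional[array] = None) -> None:
        self._prompt_token_ids = prompt_token_ids
        self._output_token_ids = output_token_ids or array(
            VLLM_TOKEN_ID_ARRAY_TYPE, [])

    @staticmethod
    def from_seqs(
        prompt_token_ids: Iterable[int],
        output_token_ids: Optional[Iterable[int]] = None,
    ) -> "SequenceData":
        # Wrap plain Python iterables (lists, ranges, ...) in the internal
        # array representation so callers such as these tests no longer need
        # to import VLLM_TOKEN_ID_ARRAY_TYPE themselves.
        prompt_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids)
        if output_token_ids is None:
            return SequenceData(prompt_arr)
        return SequenceData(prompt_arr,
                            array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids))


# Usage mirroring the updated tests:
seq_data = SequenceData.from_seqs(range(8))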