[Core] Factor out common code in SequenceData and Sequence (#8675)

This commit is contained in:
Cyrus Leung
2024-09-21 10:30:39 +08:00
committed by GitHub
parent d4bf085ad0
commit 0455c46ed4
8 changed files with 64 additions and 97 deletions

View File

@@ -1,4 +1,3 @@
from array import array
from itertools import count
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
@@ -11,8 +10,7 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob,
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceData, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
request_id=str(i),
is_prompt=len(cont_token_ids) == 0,
seq_data={
i:
SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
_output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
cont_token_ids[:]),
),
i: SequenceData.from_seqs(prompt_token_ids[:],
cont_token_ids[:]),
},
sampling_params=SamplingParams(temperature=0.0, ),
block_tables={i: block_allocations[i][:]},