[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import random
|
||||
from array import array
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@@ -10,7 +11,8 @@ from transformers import GenerationConfig, GenerationMixin
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
|
||||
SequenceData, SequenceGroupMetadata)
|
||||
from vllm.utils import Counter, is_pin_memory_available
|
||||
|
||||
|
||||
@@ -56,7 +58,9 @@ def _do_sample(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
seq_data={
|
||||
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
|
||||
},
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
@@ -201,7 +205,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
|
||||
def create_sequence_data(num_input=3, num_generated=0):
|
||||
seq_data = SequenceData(
|
||||
random.choices(range(0, VOCAB_SIZE), k=num_input))
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
random.choices(range(0, VOCAB_SIZE), k=num_input)))
|
||||
if num_generated > 0:
|
||||
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
|
||||
k=num_generated)
|
||||
@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
seq_data={
|
||||
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
|
||||
},
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
seq_data={
|
||||
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
|
||||
},
|
||||
sampling_params=SamplingParams(
|
||||
temperature=1,
|
||||
top_k=top_k,
|
||||
@@ -650,7 +659,11 @@ def test_sampler_repetition_penalty_mixed(device: str):
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
seq_data={
|
||||
0:
|
||||
SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[1, 2, 3]))
|
||||
},
|
||||
sampling_params=sampling_params[i],
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
|
||||
Reference in New Issue
Block a user