TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

This commit is contained in:
Zhuohan Li
2023-10-02 15:36:09 -07:00
committed by GitHub
parent 84e4e37d14
commit ba0bfd40e2
42 changed files with 819 additions and 1547 deletions

View File

@@ -1,3 +1,4 @@
# pylint: disable=protected-access
import pytest
import random
from typing import Tuple
@@ -108,7 +109,7 @@ def test_sampler_all_random(seed: int):
def test_sampler_all_beam(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
input_tensor, _, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
for i in range(batch_size):