TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
# pylint: disable=protected-access
|
||||
import pytest
|
||||
import random
|
||||
from typing import Tuple
|
||||
@@ -108,7 +109,7 @@ def test_sampler_all_random(seed: int):
|
||||
def test_sampler_all_beam(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||
input_tensor, _, sampler, worker = _prepare_test(batch_size)
|
||||
|
||||
seq_group_metadata_list = []
|
||||
for i in range(batch_size):
|
||||
|
||||
Reference in New Issue
Block a user