Support beam search & parallel generation (#7)

This commit is contained in:
Woosuk Kwon
2023-03-10 09:58:21 -08:00
committed by GitHub
parent 04e5acc08e
commit 1a7eb7da61
16 changed files with 660 additions and 161 deletions

View File

@@ -1,5 +1,7 @@
import random
from typing import Union
import numpy as np
import torch
import torch.nn as nn
@@ -30,3 +32,11 @@ def get_model(
model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
return model.eval()
raise ValueError(f'Invalid model name: {model_name}')
def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)