Add support for GPT-2 (#60)

This commit is contained in:
Woosuk Kwon
2023-05-04 02:59:56 -07:00
committed by GitHub
parent 130d5fd8c7
commit e548c1488a
7 changed files with 350 additions and 8 deletions

View File

@@ -11,8 +11,9 @@ from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_pa
class Sampler(nn.Module):
def __init__(self) -> None:
def __init__(self, vocab_size: int) -> None:
# Set up the sampler module.
#
# vocab_size: number of leading vocabulary columns to keep; forward()
# slices the gathered logits to `[:, :self.vocab_size]` to drop the
# padded vocab entries.
super().__init__()
# Kept on the instance so forward() can trim the logits.
self.vocab_size = vocab_size
def forward(
self,
@@ -26,6 +27,8 @@ class Sampler(nn.Module):
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab.
logits = logits[:, :self.vocab_size]
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)