Use runtime profiling to replace manual memory analyzers (#81)

This commit is contained in:
Zhuohan Li
2023-05-19 11:35:44 -06:00
committed by GitHub
parent 825d8892b5
commit f756799b84
14 changed files with 211 additions and 478 deletions

View File

@@ -74,7 +74,7 @@ class Sampler(nn.Module):
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == probs.shape[0]
- if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
+ if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks):
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
# Sample the next tokens.