[Fix] Don't deep-copy LogitsProcessors when copying SamplingParams (#3099)

This commit is contained in:
Nick Hill
2024-02-29 11:20:42 -08:00
committed by GitHub
parent 2c08ff23c0
commit 29a8d6a554
2 changed files with 18 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
"""Sampling parameters for text generation."""
import copy
from enum import IntEnum
from functools import cached_property
from typing import Callable, List, Optional, Union
@@ -237,6 +238,20 @@ class SamplingParams:
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
def clone(self) -> "SamplingParams":
"""Deep copy excluding LogitsProcessor objects.
LogitsProcessor objects are excluded because they may contain an
arbitrary, nontrivial amount of data.
See https://github.com/vllm-project/vllm/issues/3087
"""
logit_processor_refs = None if self.logits_processors is None else {
id(lp): lp
for lp in self.logits_processors
}
return copy.deepcopy(self, memo=logit_processor_refs)
def __repr__(self) -> str:
return (
f"SamplingParams(n={self.n}, "