[Fix] Don't deep-copy LogitsProcessors when copying SamplingParams (#3099)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from functools import cached_property
|
||||
from typing import Callable, List, Optional, Union
|
||||
@@ -237,6 +238,20 @@ class SamplingParams:
|
||||
return SamplingType.RANDOM_SEED
|
||||
return SamplingType.RANDOM
|
||||
|
||||
def clone(self) -> "SamplingParams":
|
||||
"""Deep copy excluding LogitsProcessor objects.
|
||||
|
||||
LogitsProcessor objects are excluded because they may contain an
|
||||
arbitrary, nontrivial amount of data.
|
||||
See https://github.com/vllm-project/vllm/issues/3087
|
||||
"""
|
||||
|
||||
logit_processor_refs = None if self.logits_processors is None else {
|
||||
id(lp): lp
|
||||
for lp in self.logits_processors
|
||||
}
|
||||
return copy.deepcopy(self, memo=logit_processor_refs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"SamplingParams(n={self.n}, "
|
||||
|
||||
Reference in New Issue
Block a user