[Bugfix] Fix OpenAI parallel sampling when using xgrammar (#11637)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
@@ -450,15 +450,16 @@ class SamplingParams(
        return self._all_stop_token_ids

def clone(self) -> "SamplingParams":
    """Deep copy, but maybe not the LogitsProcessor objects.

    LogitsProcessor objects may contain an arbitrary, nontrivial amount of
    data that is expensive to copy. However, if not copied, the processor
    needs to support parallel decoding for multiple sequences.
    See https://github.com/vllm-project/vllm/issues/3087

    Returns:
        A deep copy of this SamplingParams, except that each logits
        processor is either cloned via its own ``clone()`` (when it
        defines one) or shared by reference instead of deep-copied.
    """
    # Pre-seed deepcopy's memo so that when it encounters a logits
    # processor it substitutes the cloned/shared object rather than
    # recursively copying the processor's (potentially large) state.
    logit_processor_refs = None if self.logits_processors is None else {
        id(lp): lp.clone() if hasattr(lp, 'clone') else lp
        for lp in self.logits_processors
    }
    return copy.deepcopy(self, memo=logit_processor_refs)
Reference in New Issue
Block a user