[Bugfix] Fix llguidance backend, rollback when EOS was encountered (#25905)
Signed-off-by: Rémi Delacourt <remi@mistral.ai>
Signed-off-by: remi <remi@mistral.ai>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
@@ -111,6 +111,7 @@ class GuidanceGrammar(StructuredOutputGrammar):
|
||||
vocab_size: int
|
||||
printed_error: bool = False
|
||||
terminated: bool = False
|
||||
+ rollback_lag: int = 0
|
||||
|
||||
def check_error(self):
|
||||
if not self.printed_error:
|
||||
@@ -127,6 +128,8 @@ class GuidanceGrammar(StructuredOutputGrammar):
|
||||
"""
|
||||
|
||||
if self.ll_tokenizer.eos_token in tokens:
|
||||
+ if self.ll_matcher.is_stopped() and not self.terminated:
|
||||
+ self.rollback_lag = 1
|
||||
self.terminated = True
|
||||
|
||||
if self.ll_matcher.is_stopped():
|
||||
@@ -163,8 +166,11 @@ class GuidanceGrammar(StructuredOutputGrammar):
|
||||
return tokens[:num_tokens]
|
||||
|
||||
def rollback(self, num_tokens: int) -> None:
|
||||
- self.ll_matcher.rollback(num_tokens)
|
||||
- self.check_error()
|
||||
+ if num_tokens > 0:
|
||||
+ self.ll_matcher.rollback(num_tokens - self.rollback_lag)
|
||||
+ self.terminated = False
|
||||
+ self.rollback_lag = 0
|
||||
+ self.check_error()
|
||||
|
||||
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
|
||||
# this will automatically return [EOS] mask if the matcher is stopped
|
||||
|
||||
Reference in New Issue
Block a user