[Misc] Validate grammar and fail early (#11119)
This commit is contained in:
@@ -131,12 +131,6 @@ class GrammarConfig:
|
|||||||
max_threads: int = 8) -> GrammarConfig:
|
max_threads: int = 8) -> GrammarConfig:
|
||||||
|
|
||||||
tokenizer_hash = hash(tokenizer)
|
tokenizer_hash = hash(tokenizer)
|
||||||
# Only get tokenizer data if not already cached
|
|
||||||
if tokenizer_hash in TokenizerDataCache._cache:
|
|
||||||
encoded_vocab = None
|
|
||||||
stop_token_ids = None
|
|
||||||
backend_str = None
|
|
||||||
else:
|
|
||||||
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
|
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
|
||||||
encoded_vocab = tokenizer_data.encoded_vocab
|
encoded_vocab = tokenizer_data.encoded_vocab
|
||||||
stop_token_ids = tokenizer_data.stop_token_ids
|
stop_token_ids = tokenizer_data.stop_token_ids
|
||||||
@@ -147,6 +141,15 @@ class GrammarConfig:
|
|||||||
json_str = json.dumps(guided_params.json)
|
json_str = json.dumps(guided_params.json)
|
||||||
else:
|
else:
|
||||||
json_str = guided_params.json
|
json_str = guided_params.json
|
||||||
|
|
||||||
|
# Validate the schema and raise ValueError here if it is invalid.
|
||||||
|
# This is to avoid exceptions in model execution, which will crash
|
||||||
|
# the engine worker process.
|
||||||
|
try:
|
||||||
|
xgr.Grammar.from_json_schema(json_str)
|
||||||
|
except RuntimeError as err:
|
||||||
|
raise ValueError(str(err)) from err
|
||||||
|
|
||||||
return cls(json_str=json_str,
|
return cls(json_str=json_str,
|
||||||
vocab_size=model_config.hf_text_config.vocab_size,
|
vocab_size=model_config.hf_text_config.vocab_size,
|
||||||
encoded_vocab=encoded_vocab,
|
encoded_vocab=encoded_vocab,
|
||||||
@@ -167,6 +170,15 @@ class GrammarConfig:
|
|||||||
f"Conversion error: {str(e)}") from e
|
f"Conversion error: {str(e)}") from e
|
||||||
else:
|
else:
|
||||||
grammar_str = guided_params.grammar
|
grammar_str = guided_params.grammar
|
||||||
|
|
||||||
|
# Validate the grammar and raise ValueError here if it is invalid.
|
||||||
|
# This is to avoid exceptions in model execution, which will crash
|
||||||
|
# the engine worker process.
|
||||||
|
try:
|
||||||
|
xgr.Grammar.from_ebnf(grammar_str)
|
||||||
|
except RuntimeError as err:
|
||||||
|
raise ValueError(str(err)) from err
|
||||||
|
|
||||||
return cls(grammar_str=grammar_str,
|
return cls(grammar_str=grammar_str,
|
||||||
vocab_size=model_config.hf_text_config.vocab_size,
|
vocab_size=model_config.hf_text_config.vocab_size,
|
||||||
encoded_vocab=encoded_vocab,
|
encoded_vocab=encoded_vocab,
|
||||||
|
|||||||
@@ -26,16 +26,12 @@ def grammar_is_likely_lark(grammar_str: str) -> bool:
|
|||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Look for Lark-style rule definitions
|
# Look for GBNF rule definition
|
||||||
if ':' in line and '::=' not in line:
|
if '::=' in line:
|
||||||
return True
|
|
||||||
|
|
||||||
# Look for Lark-specific features
|
|
||||||
if any(pattern in line for pattern in ['?start:', '|', '~']):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def convert_lark_to_gbnf(grammar_str: str) -> str:
|
def convert_lark_to_gbnf(grammar_str: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user