Workflow of XGrammar¶
This tutorial introduces the workflow of XGrammar, including most of its core components. Please read constrained decoding first to understand how XGrammar achieves structured generation.
import xgrammar as xgr
import asyncio
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
Grammar¶
xgr.Grammar
describes the structure of the LLM output. It can be:
A JSON schema or free-form JSON
A regex
A custom context-free grammar in extended BNF (EBNF) format
etc.
To construct a grammar, use
grammar: xgr.Grammar = xgr.Grammar.from_json_schema(json_schema_string)
# or
grammar: xgr.Grammar = xgr.Grammar.builtin_json_grammar()
# or
grammar: xgr.Grammar = xgr.Grammar.from_regex(regex_string)
# or
grammar: xgr.Grammar = xgr.Grammar.from_ebnf(ebnf_string)
print(grammar) # print the ebnf format of the grammar
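For instance, json_schema_string can be any standard JSON Schema and ebnf_string any grammar in XGrammar's EBNF format. A minimal, hypothetical sketch (the schema and rule below are made up for illustration):
import json

# A hypothetical schema: an object with a required string "name" and integer "age"
json_schema_string = json.dumps({
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
})
grammar = xgr.Grammar.from_json_schema(json_schema_string)

# A hypothetical tiny EBNF grammar: the output must be "yes" or "no"
ebnf_string = 'root ::= "yes" | "no"'
grammar = xgr.Grammar.from_ebnf(ebnf_string)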
Tokenizer Info¶
xgr.TokenizerInfo
contains the tokenizer information of the model, which XGrammar needs in order to generate the token mask.
xgr.TokenizerInfo
can be constructed from a HuggingFace tokenizer or from a list of raw tokens.
For HuggingFace tokenizers, XGrammar supports the HuggingFace fast tokenizer, tiktoken, and SentencePiece backends.
tokenizer = AutoTokenizer.from_pretrained(...)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
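If no HuggingFace tokenizer is available, a rough sketch of building it from a raw token list (the exact parameters may differ; check the xgr.TokenizerInfo API reference):
raw_vocab = [...]  # hypothetical list of token strings, ordered by token id
tokenizer_info = xgr.TokenizerInfo(raw_vocab)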
Grammar Compiler¶
To accelerate mask generation, XGrammar performs preprocessing on the grammar using the vocabulary of the model. This process is called Grammar Compilation. During grammar compilation, we:
Simplify the grammar and build automata
Compute an adaptive token mask cache, which is used at runtime to generate the actual mask
xgr.GrammarCompiler
processes the grammar and produces an xgr.CompiledGrammar
object. Each xgr.GrammarCompiler
is bound to a specific xgr.TokenizerInfo
object. When given a grammar, it uses this tokenizer info to compile it. You can pass in a Grammar object directly, or provide a raw EBNF string, JSON Schema, or regex pattern:
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_grammar(grammar)
# or
compiled_grammar = grammar_compiler.compile_json_schema(json_schema_string)
# or
compiled_grammar = grammar_compiler.compile_builtin_json_grammar()
# or
compiled_grammar = grammar_compiler.compile_regex(regex_string)
# or
compiled_grammar = grammar_compiler.compile_grammar(ebnf_string)
Compiled Grammar¶
An xgr.CompiledGrammar
object is associated with an xgr.Grammar
object and an xgr.TokenizerInfo
object. It contains the compiled grammar and the token mask cache. Use the following attributes to access the grammar and tokenizer info:
compiled_grammar.grammar
compiled_grammar.tokenizer_info
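Because compilation is the expensive step, the same xgr.CompiledGrammar is typically reused, for example to construct one matcher per request (xgr.GrammarMatcher is introduced below):
# One compiled grammar can back several independent matchers
matchers = [xgr.GrammarMatcher(compiled_grammar) for _ in range(4)]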
Token Bitmask¶
The mask is conceptually a boolean tensor whose length equals the vocabulary size. To save memory, XGrammar compresses it into an int32 bitset. Batched settings are also supported. Use xgr.allocate_token_bitmask
to allocate a bitmask:
bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer_info.vocab_size)
The bitmask is a torch.Tensor
with shape (batch_size, ceil(vocab_size / 32)), dtype int32, and device cpu. It lives on the CPU because the grammar matcher fills it with CPU-side logic.
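For example, with a hypothetical vocabulary of 128,000 tokens, each row of the bitmask holds 4,000 int32 words (128,000 / 32):
batch_size, vocab_size = 4, 128000  # hypothetical sizes for illustration
bitmask = xgr.allocate_token_bitmask(batch_size, vocab_size)
assert bitmask.shape == (4, 4000)  # (batch_size, ceil(vocab_size / 32))
assert bitmask.dtype == torch.int32 and bitmask.device.type == "cpu"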
Grammar Matcher¶
xgr.GrammarMatcher
handles the logic of matching the LLM output to the structure and generating the token mask. It is constructed from an xgr.CompiledGrammar
object. At each step, it accepts the last token generated by the LLM with xgr.GrammarMatcher.accept_token
and then generates the mask for the next token with xgr.GrammarMatcher.fill_next_token_bitmask.
grammar_matcher = xgr.GrammarMatcher(compiled_grammar)
token_id = ... # the last token generated by the LLM
grammar_matcher.accept_token(token_id)
grammar_matcher.fill_next_token_bitmask(bitmask)
Use xgr.apply_token_bitmask_inplace
to apply the bitmask to the logits of the LLM. It modifies the logits in place. If the logits are on the GPU, move the bitmask to the same device first.
logits = ... # the logits of the LLM
xgr.apply_token_bitmask_inplace(logits, bitmask.to(logits.device))
# Sample the next token
prob = torch.softmax(logits, dim=-1)
next_token_id = torch.argmax(prob, dim=-1)
print(tokenizer.decode(next_token_id))
The Generation Loop¶
With the components above, we can write a generation loop using Hugging Face Transformers. xgr.GrammarMatcher.is_terminated
checks whether the matcher has terminated.
input_ids: list[int] = ...  # the input tokens
token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
model = AutoModelForCausalLM.from_pretrained(...).to("cuda")  # keep the model on the same device as the inputs
while not grammar_matcher.is_terminated():
    # Generate the logits. Shape: (1, seq_len, vocab_size)
    logits = model(torch.tensor([input_ids], device=torch.device("cuda"))).logits
    # Fill and apply the bitmask
    grammar_matcher.fill_next_token_bitmask(token_bitmask)
    xgr.apply_token_bitmask_inplace(logits[0, -1, :], token_bitmask.to(logits.device))
    # Sample the next token
    prob = torch.softmax(logits[0, -1, :], dim=-1)
    next_token_id = torch.argmax(prob, dim=-1)
    # Accept the token and append it to the input
    grammar_matcher.accept_token(next_token_id.item())
    input_ids.append(next_token_id.item())
# Reset the matcher so it can be used again
grammar_matcher.reset()
# Print the generated text
print(tokenizer.decode(input_ids))
xgr.GrammarMatcher.is_terminated
requires the LLM to generate an EOS token after completing the structure; in other words, the generation is EOS-terminated.
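As an illustration (assuming the structure has just been completed and the tokenizer's EOS token id is known):
# After the structured part is complete, only EOS is allowed;
# accepting it terminates the matcher
grammar_matcher.accept_token(tokenizer.eos_token_id)
assert grammar_matcher.is_terminated()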
Congratulations! You have successfully generated a structured output using XGrammar.
Next Steps¶
Read advanced topics to learn about more advanced features of XGrammar. Read integration with LLM engine to learn how to integrate XGrammar into an LLM engine.