"""Match the output of the LLM to the specified grammar, then generate the mask for the next
token.
"""
import math
from typing import List, Optional, Tuple, Union
import torch
from .base import XGRObject, _core
from .compiler import CompiledGrammar
bitmask_dtype = torch.int32
"""The dtype of the bitmask: int32."""
def get_bitmask_shape(batch_size: int, vocab_size: int) -> Tuple[int, int]:
"""Return the shape of the bitmask: (batch_size, ceil(vocab_size / 32))."""
return (batch_size, math.ceil(vocab_size / 32))
_FULL_MASK = torch.tensor(-1, dtype=bitmask_dtype)
def allocate_token_bitmask(batch_size: int, vocab_size: int) -> torch.Tensor:
"""Allocate the bitmask for the next token prediction. The bitmask is an int32 tensor on
CPU with shape (batch_size, ceil(vocab_size / 32)). Users who have their own needs to
manage CUDA memory can construct the tensor with get_bitmask_shape and bitmask_dtype
themselves.
The reason why we use int32 instead of uint32 is that old versions of PyTorch do not support
uint32.
Parameters
----------
batch_size : int
The batch size of the bitmask.
vocab_size : int
The size of the vocabulary.
Returns
-------
bitmask : torch.Tensor
The allocated bitmask, filled with the all-ones (fully allowed) mask.
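Examples
--------
A minimal allocation sketch (the sizes are illustrative; the import assumes these helpers
are re-exported at the package root):

.. code:: python

    from xgrammar import allocate_token_bitmask, get_bitmask_shape

    # 2 sequences, a vocabulary of 128000 tokens -> shape (2, 4000)
    bitmask = allocate_token_bitmask(batch_size=2, vocab_size=128000)
    assert bitmask.shape == get_bitmask_shape(2, 128000)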
"""
# In CUDA, use pinned memory to speed up data transfer from CPU to GPU
return torch.full(
    get_bitmask_shape(batch_size, vocab_size),
    _FULL_MASK,
    dtype=bitmask_dtype,
    pin_memory=torch.cuda.is_available(),
)
def reset_token_bitmask(bitmask: torch.Tensor) -> None:
"""Reset the bitmask to the full mask."""
bitmask.fill_(_FULL_MASK)
def apply_token_bitmask_inplace(
logits: torch.Tensor,
bitmask: torch.Tensor,
*,
vocab_size: Optional[int] = None,
indices: Optional[List[int]] = None,
) -> None:
"""Apply the bitmask to the logits in-place. The bitmask is a 01 bitwise compressed tensor,
where 0 means the token is masked and 1 means the token is not masked. It can be generated by
allocate_token_bitmask and filled by fill_next_token_bitmask. After applying the bitmask, the
masked logits will be set to -inf.
The shape of logits and bitmask should be (batch_size, vocab_size) and
(batch_size, bitmask_size) respectively. bitmask_size = ceil(vocab_size / 32). The operation is:
.. code:: python

    for i in range(batch_size):
        for j in range(vocab_size):
            if get_bitmask_value(bitmask, i, j) == 0:
                logits[i, j] = -inf

get_bitmask_value(bitmask, i, j) gets the j-th bit of the i-th row of the bitmask.
Notes
-----
Padding:
This method allows additional padding on the vocabulary dimension of logits or bitmask. If
padding exists, provide the real vocab size to the vocab_size parameter, and the operation
will be applied to logits[..., :vocab_size] and bitmask[..., :ceil(vocab_size / 32)].
If vocab_size is not provided, the vocab size will be detected as min(logits.shape[-1],
bitmask.shape[-1] * 32).
Indices:
Indices can be used to specify which rows of the logits in the batch to apply the bitmask to.
This is especially useful when structured and unstructured requests are mixed in the same
batch, since masking can be skipped for the unstructured requests. When specified, the
operation is
.. code:: python

    for batch_id in indices:
        for j in range(vocab_size):
            if get_bitmask_value(bitmask, batch_id, j) == 0:
                logits[batch_id, j] = -inf

When indices is specified, the batch sizes of logits and bitmask do not need to be the same.
As long as the indices are valid, the operation will be performed.
Device:
The logits and the bitmask should be on the same device. If both are on GPU, a GPU kernel is
launched to apply the bitmask; if both are on CPU, a CPU implementation is used. The GPU
kernel is optimized and should be preferred.
In practice, the bitmask is allocated on CPU while the logits are usually on GPU, so users
should manually copy the bitmask to GPU before calling this function.
Parameters
----------
logits : torch.Tensor
The tensor to apply the bitmask to.
bitmask : torch.Tensor
The bitmask to apply.
vocab_size : Optional[int], default: None
The size of the vocabulary. If not provided, the vocab size will be detected as
min(logits.shape[-1], bitmask.shape[-1] * 32).
indices : Optional[List[int]], default: None
A list of indices to specify which logits in the batch to apply the bitmask to. Should be
unique. If None, apply the bitmask to all logits in the batch.
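Examples
--------
A minimal sketch of masking only the structured requests in a mixed batch. ``logits`` (a
(batch_size, vocab_size) tensor) and ``matchers`` (one GrammarMatcher per structured request)
are assumptions supplied by the caller:

.. code:: python

    structured_rows = [0, 2]  # rows of the batch that are grammar-constrained
    bitmask = allocate_token_bitmask(logits.shape[0], logits.shape[-1])
    for row in structured_rows:
        matchers[row].fill_next_token_bitmask(bitmask, index=row)
    # Copy the CPU bitmask to the logits device, then mask only the structured rows.
    apply_token_bitmask_inplace(
        logits, bitmask.to(logits.device, non_blocking=True), indices=structured_rows
    )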
"""
if bitmask.device != logits.device:
raise ValueError(
"logits and bitmask should be on the same device. "
+ f"But got logits.device: {logits.device}, bitmask.device: {bitmask.device}"
)
# dispatch to different implementations based on the device
if logits.device.type == "cpu":
from .kernels.apply_token_bitmask_inplace_cpu import apply_token_bitmask_inplace_cpu
apply_token_bitmask_inplace_cpu(logits, bitmask, vocab_size, indices)
elif logits.device.type == "cuda":
from .kernels.apply_token_bitmask_inplace_triton import apply_token_bitmask_inplace_triton
apply_token_bitmask_inplace_triton(logits, bitmask, vocab_size, indices)
else:
from .kernels.apply_token_bitmask_inplace_torch_compile import (
apply_token_bitmask_inplace_torch_compile,
)
apply_token_bitmask_inplace_torch_compile(logits, bitmask, vocab_size, indices)
class GrammarMatcher(XGRObject):
"""Match the output of the LLM to the specified grammar, then generate the mask for the next
token. This is the core class for grammar-guided generation.
This class maintains a stateful matcher that can accept tokens and strings, then match them
to the specified grammar. The matcher can provide a bitmask for the next token prediction,
so that the output of the LLM follows the specified grammar. Its state can be reset and
rolled back by tokens. It also provides utilities for jump-forward decoding.
After matching the whole grammar, the matcher expects a stop token; the token mask at that
point only allows stop tokens. After accepting a stop token, the matcher terminates and can
no longer accept tokens or generate token masks, meaning the generation is finished.
Under the hood, it utilizes a pushdown automaton with backtracking to match the grammar,
with optimizations specific to LLM token mask generation.
Parameters
----------
compiled_grammar : CompiledGrammar
The initialization context for the grammar matcher.
override_stop_tokens : Optional[Union[int, List[int]]], default: None
If not None, the stop tokens to override the ones in the grammar.
terminate_without_stop_token : bool, default: False
Whether to terminate the matcher without accepting a stop token.
max_rollback_tokens : int, default: 0
The maximum number of rollback tokens allowed. The rollback operation is useful for
jump-forward decoding and speculative decoding.
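Examples
--------
A minimal sketch of a grammar-guided decoding loop. ``compiled_grammar``, ``vocab_size``, and
``next_token_logits`` (a hypothetical per-step forward pass of the LLM) are assumptions
supplied by the caller:

.. code:: python

    matcher = GrammarMatcher(compiled_grammar)
    bitmask = allocate_token_bitmask(batch_size=1, vocab_size=vocab_size)

    while not matcher.is_terminated():
        logits = next_token_logits()  # hypothetical: (1, vocab_size) logits from the LLM
        if matcher.fill_next_token_bitmask(bitmask):
            # Copy the CPU bitmask to the logits device before applying it.
            apply_token_bitmask_inplace(logits, bitmask.to(logits.device))
        token_id = int(torch.argmax(logits, dim=-1))  # or any other sampling strategy
        assert matcher.accept_token(token_id)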
"""
def __init__(
self,
compiled_grammar: CompiledGrammar,
*,
override_stop_tokens: Optional[Union[int, List[int]]] = None,
terminate_without_stop_token: bool = False,
max_rollback_tokens: int = 0,
) -> None:
if not isinstance(compiled_grammar, CompiledGrammar):
raise ValueError("The grammar should be compiled before passing it to GrammarMatcher.")
if isinstance(override_stop_tokens, int):
override_stop_tokens = [override_stop_tokens]
self._init_handle(
_core.GrammarMatcher(
compiled_grammar._handle,
override_stop_tokens,
terminate_without_stop_token,
max_rollback_tokens,
)
)
def accept_token(self, token_id: int, *, debug_print: bool = False) -> bool:
"""Accept one token and update the state of the matcher.
The matcher will reject the token and return False in the following cases:
1. The token does not match the grammar.
2. The matcher has already terminated (after accepting a stop token), so no new token can be
accepted.
3. The token id is out of range.
4. The token is a special token.
The caller should check the return value and handle the cases where the token is not
accepted.
Parameters
----------
token_id : int
The id of the token to accept.
debug_print : bool, default: False
Whether to print information about the internal state of the matcher. Helpful
for debugging.
Returns
-------
accepted : bool
Whether the token is accepted.
Raises
------
RuntimeError
If the recursion depth is exceeded.
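Examples
--------
A minimal sketch of handling a rejected token (``matcher`` and ``token_id`` are assumed to
come from the surrounding decoding loop):

.. code:: python

    if not matcher.accept_token(token_id):
        # The token violates the grammar or the matcher has already terminated;
        # handle it, e.g. stop generation or resample for this request.
        raise RuntimeError(f"Token {token_id} was rejected by the grammar matcher")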
"""
return self._handle.accept_token(token_id, debug_print)
def accept_string(self, input_str: Union[str, bytes], *, debug_print: bool = False) -> bool:
"""Accept a string and update the state of the matcher. The whole string is considered
as one step in rollback. It is used to complement the functionality of accept_token, and
accept_token should always be used to accept tokens.
Parameters
----------
input_str : Union[str, bytes]
The string to be accepted.
debug_print : bool, default: False
Whether to print information about the internal state of the matcher. Helpful for
debugging.
Returns
-------
accepted : bool
Whether the string is accepted.
Raises
------
RuntimeError
If the recursion depth is exceeded.
"""
return self._handle.accept_string(input_str, debug_print)
def fill_next_token_bitmask(
self, bitmask: torch.Tensor, index: int = 0, *, debug_print: bool = False
) -> bool:
"""Fill the bitmask for the next token prediction. The input bitmask can be generated
by allocate_token_bitmask, and must be on CPU. bitmask[index] will be filled with the
next token bitmask.
This method does not change the matcher state.
Parameters
----------
bitmask : torch.Tensor
The bitmask for the next token prediction.
index : int, default: 0
The batch id of the bitmask.
debug_print : bool, default: False
Whether to print information about generated bitmask. Helpful for debugging.
Returns
-------
need_apply : bool
Whether the bitmask needs to be applied (i.e. it is not all-true). As an optimization, if
this is False, the bitmask is already all-true, so applying it can be skipped.
Raises
------
RuntimeError
If the recursion depth is exceeded.
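Examples
--------
A minimal sketch of filling one shared CPU bitmask for a batch of matchers (``matchers`` and
``vocab_size`` are assumptions supplied by the caller):

.. code:: python

    bitmask = allocate_token_bitmask(len(matchers), vocab_size)  # int32, on CPU
    for i, matcher in enumerate(matchers):
        matcher.fill_next_token_bitmask(bitmask, index=i)  # fills bitmask[i]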
"""
if bitmask.device.type != "cpu":
raise ValueError("bitmask should be on CPU.")
if bitmask.dtype != bitmask_dtype:
raise ValueError(f"bitmask should be of type {bitmask_dtype}.")
return self._handle.fill_next_token_bitmask(
bitmask.data_ptr(), list(bitmask.shape), index, debug_print
)
def find_jump_forward_string(self) -> str:
"""Find the jump-forward string for jump-forward decoding. This is the longest string that
certainly conforms with the current grammar from the current matcher state. This string
can become the output of the LLM without requiring LLM decoding.
This method does not change the matcher state.
Returns
-------
jump_forward_string : str
The jump-forward string.
Raises
------
RuntimeError
If the recursion depth is exceeded.
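Examples
--------
A minimal sketch of jump-forward decoding at the matcher level (``matcher`` and
``output_text`` are assumptions; real engines typically also retokenize the extended text):

.. code:: python

    jump_str = matcher.find_jump_forward_string()
    if jump_str:
        # The string is certain under the grammar, so it can be emitted without LLM decoding.
        matcher.accept_string(jump_str)
        output_text += jump_str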
"""
return self._handle.find_jump_forward_string()
def rollback(self, num_tokens: int = 1) -> None:
"""Rollback the matcher to a previous state by several tokens.
Parameters
----------
num_tokens : int, default: 1
The number of tokens to roll back. It cannot exceed the current number of steps, nor can it
exceed the specified maximum number of rollback tokens.
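Examples
--------
A minimal sketch of undoing rejected draft tokens in speculative decoding (``draft_tokens``
and ``num_accepted`` are assumptions; the draft tokens are assumed to be grammar-valid, and
the matcher must be constructed with a sufficient max_rollback_tokens):

.. code:: python

    for token_id in draft_tokens:
        matcher.accept_token(token_id)  # tentatively accept the whole draft
    # After verification, undo the rejected suffix of the draft.
    matcher.rollback(len(draft_tokens) - num_accepted)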
"""
self._handle.rollback(num_tokens)
def is_terminated(self) -> bool:
"""Check if the matcher has terminated. If terminate_without_stop_token is False, the
matcher will terminate if it has accepted the stop token. Otherwise, the matcher will
terminate after matching the whole grammar.
Returns
-------
terminated : bool
Whether the matcher has terminated.
"""
return self._handle.is_terminated()
def reset(self) -> None:
"""Reset the matcher to the initial state."""
return self._handle.reset()
@property
def max_rollback_tokens(self) -> int:
"""Get the maximum number of rollback tokens allowed.
Returns
-------
max_rollback_tokens : int
The maximum number of rollback tokens.
"""
return self._handle.max_rollback_tokens
@property
def stop_token_ids(self) -> List[int]:
"""The ids of the stop tokens used in the matcher. If specified, the provided stop tokens
will be used. Otherwise, the stop tokens will be detected from the vocabulary.
Returns
-------
stop_token_ids : List[int]
The ids of the stop tokens.
"""
return self._handle.stop_token_ids
def _debug_print_internal_state(self) -> str:
"""Print the internal state of the matcher. This is used for debugging. The
representation of the internal state is subject to change.
Returns
-------
internal_state : str
The internal state of the matcher.
"""
return self._handle._debug_print_internal_state()