"""Defines all structural tag formats."""
import json
from typing import Any, Dict, List, Literal, Tuple, Type, Union
try:
# Python 3.9+
from typing import Annotated
except ImportError:
# Python 3.8
from typing_extensions import Annotated
from pydantic import BaseModel, Field
# ---------- Basic Formats ----------
class AnyTextFormat(BaseModel):
    """Format that accepts arbitrary text."""

    type: Literal["any_text"] = "any_text"
    """Discriminator value identifying this format; always ``"any_text"``."""

    excludes: List[str] = []
    """Strings that are not allowed to occur anywhere in the matched text."""
class TokenFormat(BaseModel):
    """Format that accepts exactly one token, given by ID or by its string form."""

    type: Literal["token"] = "token"
    """Discriminator value identifying this format; always ``"token"``."""

    token: Union[int, str]
    """The token to match: an ``int`` is a token ID, a ``str`` is the token string."""
class ExcludeTokenFormat(BaseModel):
    """Format that accepts one token, as long as it is not in the excluded set."""

    type: Literal["exclude_token"] = "exclude_token"
    """Discriminator value identifying this format; always ``"exclude_token"``."""

    exclude_tokens: List[Union[int, str]] = []
    """Tokens that may not be matched, each given as a token ID or a token string."""
class AnyTokensFormat(BaseModel):
    """Format that accepts zero or more tokens, none of which is in the excluded set."""

    type: Literal["any_tokens"] = "any_tokens"
    """Discriminator value identifying this format; always ``"any_tokens"``."""

    exclude_tokens: List[Union[int, str]] = []
    """Tokens that may not be matched, each given as a token ID or a token string."""
# ---------- Combinatorial Formats ----------
class TokenTriggeredTagsFormat(BaseModel):
    """Format that dispatches to tags when a trigger token is generated.

    This is the token-level sibling of ``TriggeredTagsFormat``: triggers are tokens
    rather than strings. The ``begin`` field of every tag must be a ``TokenFormat``,
    not a string. Use ``TriggeredTagsFormat`` when dispatch should happen on strings.
    """

    type: Literal["token_triggered_tags"] = "token_triggered_tags"
    """Discriminator value identifying this format; always ``"token_triggered_tags"``."""

    trigger_tokens: List[Union[int, str]]
    """Tokens that trigger dispatch, each given as a token ID or a token string."""

    tags: List[TagFormat]
    """Tags that can be dispatched to."""

    exclude_tokens: List[Union[int, str]] = []
    """Tokens to exclude, each given as a token ID or a token string."""

    at_least_one: bool = False
    """If true, at least one tag has to be generated."""

    stop_after_first: bool = False
    """If true, generation of this format stops once the first tag has been generated."""
class OptionalFormat(BaseModel):
    """Format whose content appears zero times or once (the EBNF "optional" operator)."""

    type: Literal["optional"] = "optional"
    """Discriminator value identifying this format; always ``"optional"``."""

    content: "Format"
    """The inner format, matched at most once."""
class PlusFormat(BaseModel):
    """Format whose content appears one or more times (the EBNF "plus" operator)."""

    type: Literal["plus"] = "plus"
    """Discriminator value identifying this format; always ``"plus"``."""

    content: "Format"
    """The inner format, matched at least once."""
class StarFormat(BaseModel):
    """Format whose content appears zero or more times (the EBNF "star" operator)."""

    type: Literal["star"] = "star"
    """Discriminator value identifying this format; always ``"star"``."""

    content: "Format"
    """The inner format, matched any number of times, including not at all."""
class RepeatFormat(BaseModel):
    """Format whose content is repeated between ``min`` and ``max`` times, both inclusive.

    Set ``max`` to ``-1`` to leave the upper bound open (i.e. "at least ``min`` times").
    """

    type: Literal["repeat"] = "repeat"
    """Discriminator value identifying this format; always ``"repeat"``."""

    min: int
    """Inclusive lower bound on the number of occurrences."""

    max: int
    """Inclusive upper bound on the number of occurrences; ``-1`` means unbounded."""

    content: "Format"
    """The inner format being repeated."""
class DispatchFormat(BaseModel):
    """Free-form text in which certain patterns switch the output to a given format.

    The user supplies a list of ``(pattern, format)`` pairs in ``rules``. The LLM may
    generate any free-form text, but once a pattern has been generated, the output that
    follows it must conform to the format paired with that pattern.

    ``loop`` decides what happens after the first pattern and format have been accepted:
    if false, this format ends; if true, free-form text is allowed again and the next
    pattern is detected.

    ``excludes`` lists strings that cannot be generated in the free-form text. It can
    also be used to end the format:

    ``SequenceFormat(DispatchFormat(..., excludes=["<end>"]), ConstStringFormat("<end>"))``

    or

    ``TagFormat(begin=..., content=DispatchFormat(..., excludes=["<end>"]), end="<end>")``

    ends the matching of this format when the LLM generates ``<end>``.
    """

    type: Literal["dispatch"] = "dispatch"
    """Discriminator value identifying this format; always ``"dispatch"``."""

    rules: List[Tuple[str, "Format"]]
    """The ``(pattern, content format)`` pairs."""

    loop: bool = True
    """If true, free-form text and pattern matching resume after a dispatched format has
    been handled. If false, this format ends after the first dispatched format."""

    excludes: List[str] = []
    """Strings that must not appear in the free-form text."""
class TokenDispatchFormat(BaseModel):
    """Free-form output in which certain tokens switch the output to a given format.

    The user supplies a list of ``(pattern token, format)`` pairs in ``rules``. The LLM
    may generate any free-form text, but once a pattern token has been generated, the
    output that follows it must conform to the format paired with that token.

    ``loop`` decides what happens after the first pattern and format have been accepted:
    if false, this format ends; if true, free-form text is allowed again and the next
    pattern is detected.

    ``exclude_tokens`` lists tokens that cannot be generated in the free-form part.

    NOTE(review): the string-level ``DispatchFormat`` documents that excluding an end
    marker lets an enclosing ``SequenceFormat`` or ``TagFormat`` end the match when that
    marker is generated; presumably ``exclude_tokens`` can be used the same way -- confirm.
    """

    type: Literal["token_dispatch"] = "token_dispatch"
    """Discriminator value identifying this format; always ``"token_dispatch"``."""

    rules: List[Tuple[Union[int, str], "Format"]]
    """The ``(pattern token, content format)`` pairs. The pattern is a token ID (``int``)
    or a token string (``str``)."""

    loop: bool = True
    """If true, free-form text and pattern matching resume after a dispatched format has
    been handled. If false, this format ends after the first dispatched format."""

    exclude_tokens: List[Union[int, str]] = []
    """Tokens that must not appear in the free-form text."""
# ---------- Discriminated Union ----------
# The plain union is named first so the discriminator annotation below stays readable.
# Members are listed in the same order as before.
_FormatUnion = Union[
    AnyTextFormat,
    ConstStringFormat,
    JSONSchemaFormat,
    GrammarFormat,
    RegexFormat,
    QwenXMLParameterFormat,
    OrFormat,
    SequenceFormat,
    TagFormat,
    TriggeredTagsFormat,
    TokenTriggeredTagsFormat,
    TagsWithSeparatorFormat,
    OptionalFormat,
    PlusFormat,
    StarFormat,
    TokenFormat,
    ExcludeTokenFormat,
    AnyTokensFormat,
    RepeatFormat,
    DispatchFormat,
    TokenDispatchFormat,
]

# The ``type`` field of each member is the discriminator pydantic uses to pick the model.
Format = Annotated[_FormatUnion, Field(discriminator="type")]
"""Union of all structural tag formats, discriminated by their ``type`` field."""
# Solve forward references
# Models whose forward references are resolved once ``Format`` is defined. The list is
# kept in a single tuple so the pydantic v1 and v2 code paths cannot drift apart.
_FORWARD_REF_MODELS = (
    SequenceFormat,
    TagFormat,
    TriggeredTagsFormat,
    TokenTriggeredTagsFormat,
    TagsWithSeparatorFormat,
    OptionalFormat,
    PlusFormat,
    StarFormat,
    RepeatFormat,
    DispatchFormat,
    TokenDispatchFormat,
)

# ``model_rebuild`` is checked first: it is the pydantic v2 API, and v2 still exposes
# the v1 name ``update_forward_refs`` as a deprecated alias.
if hasattr(BaseModel, "model_rebuild"):
    _FORWARD_REF_RESOLVER = "model_rebuild"
elif hasattr(BaseModel, "update_forward_refs"):
    _FORWARD_REF_RESOLVER = "update_forward_refs"
else:
    raise RuntimeError("Unsupported pydantic version")

# Resolve in declaration order, exactly one call per model.
for _model in _FORWARD_REF_MODELS:
    getattr(_model, _FORWARD_REF_RESOLVER)()
# ---------- Top Level ----------
class StructuralTagItem(BaseModel):
    """Deprecated. One item of a legacy structural tag definition.

    See :meth:`xgrammar.Grammar.from_structural_tag` for more details.
    """

    begin: str
    """The begin tag."""

    schema_: Union[str, Type[BaseModel], Dict[str, Any]] = Field(alias="schema")
    """The schema, populated through the ``schema`` alias. It may be a string, a
    ``BaseModel`` subclass, or a dictionary."""

    end: str
    """The end tag."""
def _legacy_schema_to_json_schema(schema: Union[str, Type[BaseModel], Dict[str, Any]]) -> Any:
    """Normalize the schema of a legacy ``StructuralTagItem`` into a JSON schema object.

    A string is parsed with ``json.loads``, a ``BaseModel`` subclass is converted to its
    JSON schema, and anything else is returned unchanged.
    """
    if isinstance(schema, str):
        return json.loads(schema)
    if isinstance(schema, type) and issubclass(schema, BaseModel):
        # ``model_json_schema`` is the pydantic v2 API; ``schema`` is its v1 equivalent.
        if hasattr(schema, "model_json_schema"):
            return schema.model_json_schema()
        return schema.schema()
    return schema


class StructuralTag(BaseModel):
    """
    Describes a complete structural tag structure. It corresponds to
    ``"response_format": {"type": "structural_tag", "format": {...}}`` in API.
    """

    type: Literal["structural_tag"] = "structural_tag"
    """The type must be "structural_tag"."""

    format: Format
    """The format of the structural tag. Could be any of the structural tag formats."""

    @staticmethod
    def from_legacy_structural_tag(
        tags: List[StructuralTagItem], triggers: List[str]
    ) -> "StructuralTag":
        """Convert legacy structural tag items to a structural tag.

        Each item becomes a ``TagFormat`` whose content is a ``JSONSchemaFormat`` built
        from the item's schema. The tags are wrapped in a ``TriggeredTagsFormat`` that
        uses ``triggers``.

        Parameters
        ----------
        tags : List[StructuralTagItem]
            The legacy items to convert.
        triggers : List[str]
            The trigger strings passed to the ``TriggeredTagsFormat``.

        Returns
        -------
        StructuralTag
            The equivalent structural tag.

        Raises
        ------
        json.JSONDecodeError
            If an item's schema is a string that is not valid JSON.
        """
        converted_tags = [
            TagFormat(
                begin=tag.begin,
                content=JSONSchemaFormat(json_schema=_legacy_schema_to_json_schema(tag.schema_)),
                end=tag.end,
            )
            for tag in tags
        ]
        return StructuralTag(
            type="structural_tag",
            format=TriggeredTagsFormat(
                type="triggered_tags", triggers=triggers, tags=converted_tags
            ),
        )

    @staticmethod
    def from_json(json_str: Union[str, Dict[str, Any]]) -> "StructuralTag":
        """Build a structural tag from a JSON string or an already-parsed dictionary.

        Parameters
        ----------
        json_str : Union[str, Dict[str, Any]]
            A JSON document as a string, or the corresponding dictionary.

        Returns
        -------
        StructuralTag
            The validated structural tag.

        Raises
        ------
        ValueError
            If ``json_str`` is neither a string nor a dictionary.
        """
        # The pydantic v2 validators are preferred. The v1 names are used as a fallback
        # so this method works on both major versions supported by this module.
        if isinstance(json_str, str):
            if hasattr(StructuralTag, "model_validate_json"):
                return StructuralTag.model_validate_json(json_str)
            return StructuralTag.parse_raw(json_str)
        if isinstance(json_str, dict):
            if hasattr(StructuralTag, "model_validate"):
                return StructuralTag.model_validate(json_str)
            return StructuralTag.parse_obj(json_str)
        raise ValueError("Invalid JSON string or dictionary")
# Public API of this module. Every format that is a member of ``Format`` and is defined
# here is exported, including the dispatch formats.
__all__ = [
    "ConstStringFormat",
    "JSONSchemaFormat",
    "QwenXMLParameterFormat",
    "AnyTextFormat",
    "GrammarFormat",
    "RegexFormat",
    "TokenFormat",
    "ExcludeTokenFormat",
    "AnyTokensFormat",
    "SequenceFormat",
    "OrFormat",
    "TagFormat",
    "TriggeredTagsFormat",
    "TokenTriggeredTagsFormat",
    "TagsWithSeparatorFormat",
    "OptionalFormat",
    "PlusFormat",
    "StarFormat",
    "RepeatFormat",
    "DispatchFormat",
    "TokenDispatchFormat",
    "Format",
    "StructuralTagItem",
    "StructuralTag",
]