Tree Annotations
Tree annotations validate nested container structures such as dicts, lists, tuples, namedtuples, and other pytree-compatible objects.
Import Tree from an explicit backend module; the examples on this page use bearshape.optree.
The root bearshape module exports Structure, T, and S, but not Tree itself.
Basic leaf checking
Tree[LeafType] means "every leaf in the pytree must satisfy LeafType".

```python
import numpy as np
from beartype import beartype
from bearshape import N, C
from bearshape.numpy import F32
from bearshape.optree import Tree

@beartype
def process(data: Tree[F32[N, C]]) -> Tree[F32[N, C]]:
    return data  # placeholder body: return the input tree unchanged

process({
    "params": np.ones((3, 4), dtype=np.float32),
    "state": np.ones((3, 4), dtype=np.float32),
})
```
Dimension bindings are shared across the whole tree: N and C are each bound once and must have the same size in every leaf. In the call above, both leaves bind N = 3 and C = 4.
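To illustrate how one shared binding table behaves, here is a minimal pure-Python sketch. It is not bearshape's implementation: ShapedLeaf, flatten, and check_leaves are hypothetical names, leaves carry only a shape (dtype checking is left out), and only dicts, lists, and tuples are treated as containers. Each dimension symbol is bound by the first leaf that uses it, and any later leaf with a different size is rejected.

```python
# Hypothetical sketch of shared dimension binding; not bearshape's API.
from dataclasses import dataclass
from typing import Any, Dict, Iterator, Tuple


@dataclass(frozen=True)
class ShapedLeaf:
    """Hypothetical stand-in for an array; only its shape is inspected."""
    shape: Tuple[int, ...]


def flatten(tree: Any) -> Iterator[Any]:
    """Yield leaves depth-first; dict children are visited in sorted key order."""
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from flatten(tree[key])
    elif isinstance(tree, (list, tuple)):
        for child in tree:
            yield from flatten(child)
    else:
        yield tree


def check_leaves(tree: Any, dims: Tuple[str, ...]) -> Dict[str, int]:
    """Check every leaf against symbolic dims using ONE shared binding table."""
    bindings: Dict[str, int] = {}
    for leaf in flatten(tree):
        if len(leaf.shape) != len(dims):
            raise TypeError(f"expected {len(dims)} dims, got shape {leaf.shape}")
        for name, size in zip(dims, leaf.shape):
            # The first leaf to mention a symbol binds it; later leaves must agree.
            bound = bindings.setdefault(name, size)
            if bound != size:
                raise TypeError(f"{name} is bound to {bound}, but a leaf has size {size}")
    return bindings


ok = {"params": ShapedLeaf((3, 4)), "state": ShapedLeaf((3, 4))}
print(check_leaves(ok, ("N", "C")))  # {'N': 3, 'C': 4}

bad = {"params": ShapedLeaf((3, 4)), "state": ShapedLeaf((5, 4))}
try:
    check_leaves(bad, ("N", "C"))
except TypeError as err:
    print(err)  # N is bound to 3, but a leaf has size 5
```

Binding per tree rather than per leaf is what lets a single annotation tie every leaf to the same N and C.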
Structure binding
Named structure symbols (T, S) require multiple arguments to share an identical tree structure:

```python
import numpy as np
from beartype import beartype
from bearshape import N, T
from bearshape.numpy import F32
from bearshape.optree import Tree

@beartype
def add_trees(
    x: Tree[F32[N], T],
    y: Tree[F32[N], T],
) -> Tree[F32[N]]:  # type: ignore[valid-type]
    ...
```
Structure symbols are runtime-only. Static type checkers understand
Tree[F32[N]], but not the extra structure arguments, so those function
signatures need a targeted # type: ignore.
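What "identical tree structure" means can be sketched by replacing every leaf with a placeholder and comparing what remains. This is a hypothetical illustration (skeleton and same_structure are made-up names), assuming that structure means container types, dict keys, and sequence lengths; a real pytree library such as optree has its own rules for which types count as nodes.

```python
# Hypothetical sketch of structure comparison; not bearshape's API.
from collections import namedtuple
from typing import Any

LEAF = "<leaf>"  # placeholder: leaf values never affect the structure


def skeleton(tree: Any) -> Any:
    """Describe the container layout of tree, ignoring the leaf values.

    Container types (including namedtuple classes), dict keys and sequence
    lengths count as structure; anything that is not a dict/list/tuple is a leaf.
    """
    if isinstance(tree, dict):
        return (dict, tuple((key, skeleton(tree[key])) for key in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree), tuple(skeleton(child) for child in tree))
    return LEAF


def same_structure(x: Any, y: Any) -> bool:
    """True if x and y have identical skeletons, so one binding of T fits both."""
    return skeleton(x) == skeleton(y)


Point = namedtuple("Point", ["x", "y"])

print(same_structure({"w": 1.0, "b": [2.0, 3.0]}, {"w": 9.0, "b": [8.0, 7.0]}))  # True
print(same_structure({"w": 1.0, "b": [2.0, 3.0]}, {"w": 9.0, "b": [8.0]}))       # False
print(same_structure(Point(1, 2), (1, 2)))                                       # False
```

In this sketch the leaf values are free to differ, which is what add_trees needs: x and y hold different numbers in the same layout.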
Multi-level structure matching
When several structure names are given, they are matched from the outermost level inward, unless an ellipsis (...) reverses the direction or truncates the match.
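Full structure binding
Tree[LeafType, T] binds the entire tree structure to T; add_trees above uses this form.

Top-level only
Trailing ... makes each name capture only one level, with inner levels unchecked. Here T binds the top level, S binds the next level, and anything deeper is not checked:

```python
@beartype
def g(x: Tree[F32[N], T, S, ...]):  # type: ignore[valid-type]
    ...
```

Bottom-level only
Leading ... matches names from the bottom up. Here S binds the bottom level and T binds the level above it:

```python
@beartype
def h(x: Tree[F32[N], ..., T, S]):  # type: ignore[valid-type]
    ...
```

Two-level matching
Without ..., T binds only the top level and S binds all of the remaining inner structure:

```python
@beartype
def f(x: Tree[int, T, S], y: Tree[int, T, S]):  # type: ignore[valid-type]
    ...
```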
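The difference between full and truncated matching can be pictured as comparing container skeletons cut off at a fixed depth. The sketch below is hypothetical and assumes that one "level" means one layer of containers; skeleton, LEAF, and ANY are made-up names, and the leading-... (bottom-up) case is not modeled.

```python
# Hypothetical sketch of depth-limited structure matching; not bearshape's API.
from typing import Any, Optional

LEAF = "<leaf>"  # a leaf position inside the checked levels
ANY = "<any>"    # an unchecked subtree below the cut-off depth


def skeleton(tree: Any, depth: Optional[int] = None) -> Any:
    """Container skeleton of tree.

    depth=None keeps every level (full structure binding); depth=k keeps the
    outermost k container levels and replaces everything below them with ANY.
    """
    if depth == 0:
        return ANY
    inner = None if depth is None else depth - 1
    if isinstance(tree, dict):
        return (dict, tuple((key, skeleton(tree[key], inner)) for key in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree), tuple(skeleton(child, inner) for child in tree))
    return LEAF


a = {"enc": [1.0, 2.0], "dec": {"w": 3.0}}
b = {"enc": [4.0], "dec": 5.0}                # same top-level keys, different insides
c = {"enc": [7.0, 8.0], "dec": {"w": [9.0]}}  # differs only at the third level

print(skeleton(a, 1) == skeleton(b, 1))  # True: only the top level is compared
print(skeleton(a) == skeleton(b))        # False: the full structures differ
print(skeleton(a, 2) == skeleton(c, 2))  # True: the first two levels agree
print(skeleton(a) == skeleton(c))        # False
```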
Custom structure symbols
Create your own with Structure:

```python
from beartype import beartype
from bearshape import N, Structure
from bearshape.numpy import F32, I64
from bearshape.optree import Tree

Params = Structure("Params")
State = Structure("State")

@beartype
def train(params: Tree[F32[N], Params],
          state: Tree[I64[N], State]):  # type: ignore[valid-type]
    ...
```
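Presumably each symbol is bound independently, so Params and State may describe different structures within one call, while two arguments annotated with the same symbol must match. The following is a hypothetical sketch of that bookkeeping, not the Structure class itself: Symbol, skeleton, and check_call are made-up names, and starting a fresh binding table for every call is an assumption.

```python
# Hypothetical sketch of per-symbol structure binding; not bearshape's API.
from typing import Any, Dict, Tuple


class Symbol:
    """Hypothetical stand-in for an object returned by Structure(name)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"Symbol({self.name!r})"


def skeleton(tree: Any) -> Any:
    """Container types, dict keys and lengths; every leaf collapses to None."""
    if isinstance(tree, dict):
        return (dict, tuple((key, skeleton(tree[key])) for key in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree), tuple(skeleton(child) for child in tree))
    return None


def check_call(*pairs: Tuple[Symbol, Any]) -> Dict[Symbol, Any]:
    """Check one call's (symbol, tree) pairs against a fresh binding table."""
    bindings: Dict[Symbol, Any] = {}  # keyed by symbol object identity (assumption)
    for symbol, tree in pairs:
        layout = skeleton(tree)
        # The first tree annotated with a symbol binds it; later ones must match.
        if bindings.setdefault(symbol, layout) != layout:
            raise TypeError(f"{symbol.name}: tree structure does not match")
    return bindings


Params = Symbol("Params")
State = Symbol("State")
params = {"w": [1.0, 2.0], "b": 0.5}
state = (0, 0)

check_call((Params, params), (State, state))  # passes: different symbols are independent
try:
    check_call((Params, params), (Params, state))
except TypeError as err:
    print(err)  # Params: tree structure does not match
```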
Static typing split
Checker-friendly:
- Tree[object]
- Tree[int]
- Tree[F32[N]]
- Tree[F32[N, C]]

Runtime-only add-ons:
- Tree[F32[N], T]
- Tree[F32[N], T, ...]
- Tree[F32[N], ..., T]
- any custom Structure symbol inside the subscript
Summary

| Pattern | Meaning |
| --- | --- |
| Tree[LeafType] | Leaf checking only |
| Tree[LeafType, T] | Full structure binding |
| Tree[LeafType, T, ...] | Top-level only |
| Tree[LeafType, ..., T] | Bottom-level only |
| Tree[LeafType, T, S] | T = top (one level), S = full remaining structure |
| Tree[LeafType, T, S, ...] | T = top, S = next, inner unchecked |
| Tree[LeafType, ..., T, S] | S = bottom, T = second-from-bottom |