Compare commits

...

2 Commits

Author SHA1 Message Date
chacha
df50632458 work 2025-09-29 21:14:24 +02:00
cclecle
25d5339946 work 2025-09-28 21:18:35 +02:00
2 changed files with 686 additions and 33 deletions

View File

@@ -1,7 +1,22 @@
"""library's internal tools"""
from collections import ChainMap
from typing import Any, Annotated, get_origin, get_args, Union, Self, Optional, List, Dict, Tuple, Set, FrozenSet, Mapping, Callable
from typing import (
Any,
Annotated,
get_origin,
get_args,
Union,
Self,
Optional,
List,
Dict,
Tuple,
Set,
FrozenSet,
Mapping,
Callable,
)
import typing
from dataclasses import dataclass
from types import UnionType, NoneType
@@ -79,6 +94,13 @@ class AnnotationWalkerCtx:
self.__allowed_annotations: dict[str, Any] = allowed_annotations
self.__ext: dict[Any, ChainMap] = {} # per-trigger namespaces (lazy)
# NEW: edge routing metadata (walker sets; triggers read)
self.__edge_role: str | None = (
None # 'elem' | 'key' | 'val' | 'branch' | 'annotated' | 'arg'
)
self.__edge_token: Any | None = None # index/key/branch-id/etc.
self.__arg_index: int | None = None # which schema arg of the parent this child is
# Read-only accessor for the private ``__origin`` attribute of this context node.
@property
def origin(self) -> Any:
return self.__origin
@@ -99,6 +121,19 @@ class AnnotationWalkerCtx:
def allowed_annotations(self) -> Mapping[str, Any]:
return self.__allowed_annotations
# NEW: read-only edge metadata for routing inside triggers
# Role of the edge leading to this node, as stamped by _set_edge():
# 'elem' | 'key' | 'val' | 'branch' | 'annotated' | 'arg'; None until stamped.
@property
def edge_role(self) -> str | None:
return self.__edge_role
# Free-form token attached to the edge by _set_edge() (index/key/branch-id/etc.);
# None until stamped.
@property
def edge_token(self) -> Any | None:
return self.__edge_token
# Position of this node among the parent's schema args, as stamped by _set_edge();
# None until stamped.
@property
def arg_index(self) -> int | None:
return self.__arg_index
def ns(self, owner: Any) -> ChainMap:
"""
A per-trigger overlay namespace that inherits from parent ctx.
@@ -107,11 +142,32 @@ class AnnotationWalkerCtx:
"""
if owner in self.__ext:
return self.__ext[owner]
parent_map = self.__parent.__ext.get(owner) if (self.__parent and hasattr(self.__parent, "_AnnotationWalkerCtx__ext")) else {}
cm = ChainMap({}, parent_map if isinstance(parent_map, ChainMap) else dict(parent_map))
# Determine the parent chain for this owner
parent_chain = ()
if self.__parent is not None and hasattr(self.__parent, "_AnnotationWalkerCtx__ext"):
parent_map = self.__parent._AnnotationWalkerCtx__ext.get(
owner
) # access private attr intentionally
if isinstance(parent_map, ChainMap):
# IMPORTANT: preserve the whole chain (not the ChainMap object as a single mapping)
parent_chain = tuple(parent_map.maps)
elif isinstance(parent_map, dict):
parent_chain = (parent_map,)
else:
parent_chain = ()
# Build a new ChainMap whose first map is this node's local overlay
cm = ChainMap({}, *parent_chain)
self.__ext[owner] = cm
return cm
# INTERNAL (walker only): set edge metadata on a child ctx
def _set_edge(self, *, role: str | None, token: Any | None, arg_index: int | None) -> None:
# Overwrite all three edge-metadata fields at once; they are exposed read-only
# through the edge_role / edge_token / arg_index properties.
# All arguments are keyword-only, and each may be None.
self.__edge_role = role
self.__edge_token = token
self.__arg_index = arg_index
@dataclass(frozen=True)
class TriggerResult:
@@ -169,6 +225,12 @@ class AnnotationTrigger:
def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
# Default hook: does nothing. Returning None means "no result" for this trigger.
return None
def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
# Default exit hook: does nothing and returns None (no replacement value).
return None
def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]) -> None:
# Default end-of-run hook: does nothing. ``ctx`` is typed Optional, so
# overrides must tolerate None.
pass
class LAMSchemaValidation(AnnotationTrigger):
def init_trigger(self) -> None:
@@ -188,13 +250,17 @@ class LAMSchemaValidation(AnnotationTrigger):
print("process_union")
print(ctx.args)
if (len(ctx.args) != 2) or (type(None) not in list(ctx.args)):
raise UnsupportedFieldType("Union[] is only supported to implement Optional[] (takes 2 parameters, including None)")
raise UnsupportedFieldType(
"Union[] is only supported to implement Optional[] (takes 2 parameters, including None)"
)
return None
def process_dict(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
print("process_dict")
if len(ctx.args) != 2:
raise IncompletelyAnnotatedField(f"Dict Annotation requires 2 inner definitions: {ctx.origin}")
raise IncompletelyAnnotatedField(
f"Dict Annotation requires 2 inner definitions: {ctx.origin}"
)
if not ctx.args[0] in ctx.allowed_types:
raise IncompletelyAnnotatedField(f"Dict Key must be simple builtin: {ctx.origin}")
return None
@@ -220,10 +286,462 @@ class LAMSchemaValidation(AnnotationTrigger):
def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
print("process_allowed")
if ctx.origin is type(None) or ctx.origin is None:
if ctx.parent is None or not (ctx.parent.origin is Union or ctx.parent.origin is UnionType):
raise IncompletelyAnnotatedField(f"None is only accepted with Union, to implement Optional[]")
if ctx.parent is None or not (
ctx.parent.origin is Union or ctx.parent.origin is UnionType
):
raise IncompletelyAnnotatedField(
f"None is only accepted with Union, to implement Optional[]"
)
return None
def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
# Exit hook: only traces the origin of the node being left; no validation here.
# NOTE(review): debug print() left in library code - consider logging instead.
print(f"process_exit: {ctx.origin}")
return None
def end_trigger(self, ctx: AnnotationWalkerCtx):
# Nothing to finalise for this trigger.
# NOTE(review): the base-class signature visible in this file accepts
# Optional[AnnotationWalkerCtx]; this override is annotated as non-optional.
pass
class DataValidation(AnnotationTrigger):
    """Skeleton trigger that carries a runtime value alongside the annotation walk.

    No validation is implemented yet: every hook only materialises this
    trigger's per-node namespace (seeding it on first use) and returns None.
    """

    def __init__(self, value: Any) -> None:
        """Remember the root value that will be seeded into the first namespace."""
        self._root = value
        # Defined here as well so that _bag() is safe even if init_trigger()
        # has not been called yet.
        self._seeded = False

    def init_trigger(self) -> None:
        """Reset the seeding flag so the next `_bag` call seeds again."""
        self._seeded = False

    def _bag(self, ctx: AnnotationWalkerCtx):
        """Return this trigger's namespace for `ctx`, seeding it on first call.

        The first namespace requested after `init_trigger` receives
        ``"value"`` (the root value) and ``"path"`` (an empty tuple).
        """
        bag = ctx.ns(self)
        if not self._seeded:
            bag["value"] = self._root
            bag["path"] = ()
            self._seeded = True
        return bag

    # The hooks below are placeholders: they only make sure the namespace
    # exists for the visited node.

    def process_annotated(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_union(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_dict(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_tuple(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_list(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_set(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_unknown(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]) -> None:
        """Nothing to finalise; accepts None like the base-class hook."""
        return None
class SchemaValidationError(TypeError):
    """Validation failure located by a path into the validated value.

    The message is ``"<rendered path>: <msg>"``. Integer path parts render as
    ``[n]``; any other part renders as ``.part`` (without the leading dot when
    it is the first part). An empty path renders as ``<root>``.

    Attributes:
        path: the path tuple exactly as passed in.
    """

    def __init__(self, path: tuple[Any, ...], msg: str):
        rendered = []
        for position, part in enumerate(path):
            if isinstance(part, int):
                rendered.append(f"[{part}]")
            elif position:
                rendered.append(f".{part}")
            else:
                rendered.append(str(part))
        # Fall back to a readable marker when the path renders to nothing.
        dotted = "".join(rendered) or "<root>"
        super().__init__(f"{dotted}: {msg}")
        self.path = path
class HorizontalValidationTrigger(AnnotationTrigger):
    """Validate a runtime value against the annotation while it is being walked.

    At every schema node the trigger keeps a list of *candidates*: dicts
    ``{"id", "value", "path"[, "union_roots"]}`` describing the runtime values
    that must match that node. Container hooks check the container type and
    publish the contained values as ``pending_*`` lists in the node's local
    namespace; each child node picks up the list that matches its edge role.
    Leaves are checked in `process_allowed`. Unions are decided in
    `process_exit`: a candidate passes when at least one branch recorded no
    error for it. Collected errors are raised together from `end_trigger`.
    """

    # Edge role -> key (in the parent's local namespace) holding the values
    # the child node has to validate. Role 'arg' is handled separately because
    # its key depends on the argument index.
    _PENDING_KEY_BY_ROLE = {
        "elem": "pending_elem",
        "key": "pending_key",
        "val": "pending_val",
    }

    def __init__(
        self,
        value: Any,
        *,
        strict_bool: bool = False,
        collect_all: bool = True,
        union_summary_max_per_branch: int = 3,
        show_branch_types: bool = True,
    ):
        """Store the value to validate and the reporting options.

        Args:
            value: root runtime value validated against the walked annotation.
            strict_bool: see the note in `_check_leaf`.
            collect_all: when False, the first error recorded outside a union
                branch is raised immediately instead of being collected.
            union_summary_max_per_branch: max errors listed per branch in a
                union failure summary.
            show_branch_types: stored; not read by the code in this class.
        """
        self._root_value = value
        self._strict_bool = strict_bool
        self._collect_all = collect_all
        self._union_summary_max = union_summary_max_per_branch
        self._show_branch_types = show_branch_types
        self._seeded = False
        self._id_counter = 0

    # ---------- utilities ----------

    def _mk_cand(self, value: Any, path: tuple[Any, ...]):
        """Create a candidate dict with a fresh, unique id."""
        cid = self._id_counter
        self._id_counter += 1
        return {"id": cid, "value": value, "path": path}

    def _spawn_from(self, bag, parent_cand: dict, value: Any, path: tuple[Any, ...]):
        """Create a child candidate that inherits `parent_cand`'s union roots.

        ``union_roots`` maps a union scope id to the id of the candidate that
        entered that union, so failures deep inside a branch can be attributed
        to the candidate the union has to decide on.
        """
        child = self._mk_cand(value, path)
        roots = dict(parent_cand.get("union_roots", {}))
        child["union_roots"] = roots
        if "__union_branch_id" in bag and "__union_state_ref" in bag:
            scope = bag.get("__union_scope_id")
            if scope is not None:
                # `roots` is already a copy of the parent's map, so the only
                # missing case is "parent itself is the root for this scope".
                roots.setdefault(scope, parent_cand["id"])
        return child

    def _route_candidates(self, ctx: AnnotationWalkerCtx, parent_bag) -> list:
        """Pick the values a child node must validate, based on its edge role.

        Only the parent's *local* map is read: a ChainMap lookup could resolve
        to a ``pending_*`` list published by a more distant ancestor.
        """
        parent_local = parent_bag.maps[0]
        role = ctx.edge_role
        if role == "arg":
            key = f"pending_arg_{ctx.arg_index}"
        else:
            # 'branch', 'annotated' (and unstamped edges) validate the very
            # same values as their parent node.
            key = self._PENDING_KEY_BY_ROLE.get(role, "candidates")
        return list(parent_local.get(key, []))

    def _bag(self, ctx: AnnotationWalkerCtx):
        """Return this trigger's namespace for `ctx`, prepared for validation.

        On first use for a node this seeds (root) or routes (child) the local
        ``candidates`` list, copies the enclosing union flags into the local
        map and, for a direct union branch, binds the branch state and stamps
        ``union_roots`` on the candidates.
        """
        bag = ctx.ns(self)
        local = bag.maps[0]
        parent_bag = ctx.parent.ns(self) if ctx.parent is not None else None

        # ------------- seeding (root) / routing (children) -------------
        if "candidates" not in local:
            if parent_bag is None:
                local["candidates"] = [self._mk_cand(self._root_value, ())]
                local.setdefault("errors", [])
                self._seeded = True
            else:
                local["candidates"] = self._route_candidates(ctx, parent_bag)

        # ------------- inherit enclosing union flags to all descendants -------------
        if parent_bag is not None:
            for k in ("__union_state_ref", "__union_branch_id", "__union_scope_id"):
                if (k in parent_bag) and (k not in local):
                    local[k] = parent_bag[k]

        # ------------- immediate child of a Union (branch) -------------
        if (
            ctx.parent is not None
            and ctx.parent.origin is UnionType
            and ctx.edge_role == "branch"
        ):
            union_bag = parent_bag
            bid = ctx.edge_token
            bstate = self._branch_state(union_bag, bid)
            local["errors"] = bstate["errors"]
            local["__union_branch_id"] = bid
            local["__union_state_ref"] = union_bag
            # keep current union scope id (from parent union node)
            local["__union_scope_id"] = union_bag.get("__union_scope_id")

            # Stamp union_roots on copies, so the union node's own candidate
            # dicts are left untouched.
            cur_scope = local.get("__union_scope_id")
            stamped = []
            for cand in local.get("candidates", []):
                copy = dict(cand)
                roots = dict(copy.get("union_roots", {}))
                if cur_scope is not None:
                    roots.setdefault(cur_scope, copy["id"])
                copy["union_roots"] = roots
                stamped.append(copy)
            local["candidates"] = stamped

        return bag

    @staticmethod
    def _branch_state(union_bag, bid):
        """Return (creating if needed) the state dict of branch `bid`."""
        branches = union_bag.setdefault("union_branches_state", {})
        return branches.setdefault(bid, {"failed_ids": set(), "errors": []})

    @staticmethod
    def _union_root_id(bag, cand: dict):
        """Id of the candidate the nearest enclosing union decides on.

        Uses the ``union_roots`` map maintained by `_spawn_from` / branch
        entry; falls back to the legacy ``union_cid`` key, then to the
        candidate's own id.
        """
        scope = bag.get("__union_scope_id")
        roots = cand.get("union_roots", {})
        if scope in roots:
            return roots[scope]
        return cand.get("union_cid", cand["id"])

    def _mark_branch_fail(self, bag, cand: dict):
        """Flag `cand`'s union root as failed in the current branch, if any."""
        if "__union_branch_id" in bag and "__union_state_ref" in bag:
            bstate = self._branch_state(bag["__union_state_ref"], bag["__union_branch_id"])
            bstate["failed_ids"].add(self._union_root_id(bag, cand))

    def _err(self, bag, cand, msg: str):
        """Record a validation error for `cand`.

        Inside a union branch the error is stored in the branch state together
        with the union-root candidate id and the union scope, so the union can
        attribute it on exit. Outside any branch it goes to the shared error
        list.

        Raises:
            SchemaValidationError: outside a union branch, when `collect_all`
                is False.
        """
        e = SchemaValidationError(cand["path"], msg)

        if ("__union_branch_id" in bag) and ("__union_state_ref" in bag):
            root_id = self._union_root_id(bag, cand)
            bstate = self._branch_state(bag["__union_state_ref"], bag["__union_branch_id"])
            bstate["failed_ids"].add(root_id)
            bstate["errors"].append(
                {"cid": root_id, "scope": bag.get("__union_scope_id"), "err": e}
            )
            return

        # Outside any union branch: add to global errors (and optional fail-fast)
        bag.setdefault("errors", []).append(e)
        if not self._collect_all:
            raise e

    def _check_leaf(self, bag, cand: dict, T: type):
        """Check one candidate against the leaf type `T`, recording a mismatch."""
        v = cand["value"]
        if T is type(None):
            if v is not None:
                self._err(bag, cand, "expected None")
            return
        # NOTE(review): the generic check below already reports a non-bool
        # value for T=bool with the same message, so strict_bool does not
        # change the outcome here. Confirm whether it was meant to reject
        # bool values for int/float leaves.
        if self._strict_bool and T is bool and type(v) is not bool:
            self._err(bag, cand, f"expected bool, got {type(v).__name__}")
            return
        if v is None or not isinstance(v, T):
            got = "None" if v is None else type(v).__name__
            self._err(bag, cand, f"expected {T.__name__}, got {got}")

    # ---------- entry hooks (seed pending; validate leaves; no aggregation here) ----------

    def process_allowed(self, ctx: AnnotationWalkerCtx):
        """Leaf node: check every candidate against the node's type."""
        bag = self._bag(ctx)
        leaf_type = ctx.origin
        for cand in bag.get("candidates", []):
            self._check_leaf(bag, cand, leaf_type)

    def process_list(self, ctx: AnnotationWalkerCtx):
        """Check candidates are lists and publish their items as ``pending_elem``."""
        bag = self._bag(ctx)
        cands = bag.get("candidates", [])
        if len(ctx.args) != 1:
            for cand in cands:
                self._err(bag, cand, "List[T] requires 1 argument")
            return
        pending = []
        for cand in cands:
            v = cand["value"]
            if not isinstance(v, list):
                self._err(bag, cand, f"expected list, got {type(v).__name__}")
                continue
            base = cand["path"]
            for i, item in enumerate(v):
                pending.append(self._spawn_from(bag, cand, item, base + (i,)))
        bag["pending_elem"] = pending

    def process_tuple(self, ctx: AnnotationWalkerCtx):
        """Check candidates are tuples and publish their items.

        ``tuple[T, ...]`` publishes every item as ``pending_elem``; a fixed
        tuple also checks the length and publishes slot ``i`` as
        ``pending_arg_{i}``.
        """
        bag = self._bag(ctx)
        cands = bag.get("candidates", [])

        if len(ctx.args) == 2 and ctx.args[1] is Ellipsis:
            pending = []
            for cand in cands:
                v = cand["value"]
                if not isinstance(v, tuple):
                    self._err(bag, cand, f"expected tuple, got {type(v).__name__}")
                    continue
                base = cand["path"]
                for i, item in enumerate(v):
                    pending.append(self._spawn_from(bag, cand, item, base + (i,)))
            bag["pending_elem"] = pending
            return

        arity = len(ctx.args)
        # Every slot gets a (possibly empty) list so children never have to
        # guess whether their slot was published.
        slots: dict[int, list] = {i: [] for i in range(arity)}
        for cand in cands:
            v = cand["value"]
            if not isinstance(v, tuple):
                self._err(bag, cand, f"expected tuple, got {type(v).__name__}")
                continue
            if len(v) != arity:
                self._err(bag, cand, f"expected tuple len {arity}, got {len(v)}")
                continue
            base = cand["path"]
            for i, item in enumerate(v):
                slots[i].append(self._spawn_from(bag, cand, item, base + (i,)))
        for i, group in slots.items():
            bag[f"pending_arg_{i}"] = group

    def process_set(self, ctx: AnnotationWalkerCtx):
        """Check candidates are sets/frozensets and publish items as ``pending_elem``."""
        bag = self._bag(ctx)
        cands = bag.get("candidates", [])
        if len(ctx.args) != 1:
            for cand in cands:
                self._err(bag, cand, "Set[T]/FrozenSet[T] requires 1 argument")
            return
        pending = []
        for cand in cands:
            v = cand["value"]
            if not isinstance(v, (set, frozenset)):
                self._err(bag, cand, f"expected set/frozenset, got {type(v).__name__}")
                continue
            base = cand["path"]
            for item in v:
                pending.append(self._spawn_from(bag, cand, item, base + ("<elem>",)))
        bag["pending_elem"] = pending

    def process_dict(self, ctx: AnnotationWalkerCtx):
        """Check candidates are dicts; publish ``pending_key`` and ``pending_val``."""
        bag = self._bag(ctx)
        cands = bag.get("candidates", [])
        if len(ctx.args) != 2:
            for cand in cands:
                self._err(bag, cand, "Dict[K,V] requires 2 arguments")
            return
        pkeys, pvals = [], []
        for cand in cands:
            v = cand["value"]
            if not isinstance(v, dict):
                self._err(bag, cand, f"expected dict, got {type(v).__name__}")
                continue
            base = cand["path"]
            for k, val in v.items():
                pkeys.append(self._spawn_from(bag, cand, k, base + ("<key>",)))
                pvals.append(self._spawn_from(bag, cand, val, base + (k,)))
        bag["pending_key"] = pkeys
        bag["pending_val"] = pvals

    def process_annotated(self, ctx: AnnotationWalkerCtx):
        """No checks here; the inner type validates the routed candidates."""
        self._bag(ctx)

    def process_union(self, ctx: AnnotationWalkerCtx):
        """Open a union scope: record its candidates, labels and fresh state."""
        bag = self._bag(ctx)
        local = bag.maps[0]
        cands = bag.get("candidates", [])

        # Remember the enclosing union scope (None at top level) before it is
        # replaced below. Written to the local map so a deeper union never
        # reuses the value stored by an ancestor union.
        local["__outer_union_scope_id"] = bag.get("__union_scope_id")

        local["union_candidate_ids"] = [c["id"] for c in cands]
        # Local on purpose: a ChainMap-wide setdefault would hand a nested
        # union the branch-state dict of the union that encloses it.
        local.setdefault("union_branches_state", {})  # bid -> {"failed_ids", "errors"}
        local["union_branch_labels"] = [self._pretty_type(a) for a in ctx.args]

        # Set THIS union's scope id (distinct from the preserved outer scope)
        local["__union_scope_id"] = id(ctx)

    # ---------- exit (aggregation only) ----------

    @staticmethod
    def _scoped_errors(state: dict, scope_id) -> dict:
        """Group one branch's errors by candidate id, limited to `scope_id`.

        Dict items carry their own cid/scope. Bare `SchemaValidationError`
        items (summaries appended by nested unions) carry neither, so they are
        attributed to every candidate the branch marked as failed. Duplicates
        (same path and message) are dropped.
        """
        grouped: dict[int, list[SchemaValidationError]] = {}
        for item in state.get("errors", []):
            if isinstance(item, dict):
                if item.get("scope") != scope_id:
                    continue
                cid, err = item.get("cid"), item.get("err")
                if cid is not None and isinstance(err, SchemaValidationError):
                    grouped.setdefault(cid, []).append(err)
            elif isinstance(item, SchemaValidationError):
                for cid in state.get("failed_ids", ()):
                    grouped.setdefault(cid, []).append(item)

        deduped: dict[int, list[SchemaValidationError]] = {}
        for cid, errs in grouped.items():
            seen = set()
            unique = []
            for e in errs:
                key = (e.path, str(e))
                if key not in seen:
                    seen.add(key)
                    unique.append(e)
            deduped[cid] = unique
        return deduped

    def process_exit(self, ctx: AnnotationWalkerCtx):
        """Decide unions on the way out; other node kinds are ignored.

        For each candidate that entered the union: it passes if any branch has
        no scoped error for it. Otherwise a summary error is recorded and,
        when this union itself sits inside a branch of an outer union, the
        failure is bubbled to that branch.
        """
        if ctx.origin is not UnionType:
            return None

        bag = ctx.ns(self)
        root_errors = bag.setdefault("errors", [])
        branches = bag.get("union_branches_state", {})
        cand_ids = bag.get("union_candidate_ids", [])
        labels = bag.get("union_branch_labels", [])
        scope_id = bag.get("__union_scope_id")
        outer_scope_id = bag.get("__outer_union_scope_id")
        cands_by_id = {c["id"]: c for c in bag.get("candidates", [])}

        per_branch_errors = {
            bid: self._scoped_errors(state, scope_id) for bid, state in branches.items()
        }

        for cid in cand_ids:
            # Decision: pass if ANY branch has zero scoped errors
            if any(not errs.get(cid) for errs in per_branch_errors.values()):
                continue

            # --- Bubble to the enclosing union if this union sits in a branch ---
            if ("__union_state_ref" in bag) and ("__union_branch_id" in bag):
                outer_bag = bag["__union_state_ref"]
                outer_state = self._branch_state(outer_bag, bag["__union_branch_id"])
                local_cand = cands_by_id.get(cid)
                if local_cand is not None and outer_scope_id is not None:
                    outer_root = local_cand.get("union_roots", {}).get(
                        outer_scope_id, local_cand["id"]
                    )
                else:
                    outer_root = cid
                outer_state["failed_ids"].add(outer_root)
                outer_state["errors"].append(
                    {
                        "cid": outer_root,
                        "scope": outer_bag.get("__union_scope_id"),
                        "err": SchemaValidationError((), "mismatch"),
                    }
                )

            # Build readable per-branch summary at THIS union site
            path = cands_by_id[cid]["path"] if cid in cands_by_id else ()
            limit = self._union_summary_max
            lines = ["no union branch matched; tried:"]
            for bid in branches:
                in_range = isinstance(bid, int) and 0 <= bid < len(labels)
                label = labels[bid] if in_range else f"branch {bid}"
                these = per_branch_errors[bid].get(cid, [])
                if not these:
                    lines.append(f"  - {label}: mismatch")
                    continue
                lines.append(f"  - {label}: {len(these)} issue(s)")
                for e in these[:limit]:
                    lines.append(f"      - {e}")
                if len(these) > limit:
                    lines.append(f"      - (+{len(these) - limit} more)")
            root_errors.append(SchemaValidationError(path, "\n".join(lines)))

        return None

    def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]):
        """Raise all errors collected in the namespace of `ctx`, if any.

        Raises:
            ExceptionGroup: wrapping every collected error.
        """
        if ctx is None:
            return
        errors = ctx.ns(self).get("errors", [])
        if errors:
            # Raise them together; swap for your preferred error carrier if needed
            raise ExceptionGroup("schema validation failed", errors)

    # --- pretty type for messages (best-effort) ---
    def _pretty_type(self, t: Any) -> str:
        """Render an annotation as a short, readable string (best-effort)."""
        origin = get_origin(t) or t
        args = get_args(t)
        try:
            if origin is UnionType or origin is Union:
                return " | ".join(self._pretty_type(a) for a in args)
            if origin is dict and len(args) == 2:
                return f"dict[{self._pretty_type(args[0])}, {self._pretty_type(args[1])}]"
            if origin in (list, set, frozenset) and len(args) == 1:
                return f"{origin.__name__}[{self._pretty_type(args[0])}]"
            if origin is tuple and args:
                if len(args) == 2 and args[1] is Ellipsis:
                    return f"tuple[{self._pretty_type(args[0])}, ...]"
                return f"tuple[{', '.join(self._pretty_type(a) for a in args)}]"
            if origin is Annotated and args:
                return f"Annotated[{self._pretty_type(args[0])}, ...]"
            if isinstance(origin, type):
                return origin.__name__
            return str(t)
        except Exception:
            # Deliberately best-effort: a label must never break validation.
            return repr(t)
class AnnotationWalker:
DEFAULT_ALLOWED_TYPES = frozenset({str, int, float, complex, bool, bytes, NoneType})
@@ -235,7 +753,7 @@ class AnnotationWalker:
"Dict": Dict,
"Tuple": Tuple,
"Set": Set,
"FrozenSet": FrozenSet,
# "FrozenSet": FrozenSet,
"Annotated": Annotated,
# builtins:
"int": int,
@@ -248,7 +766,7 @@ class AnnotationWalker:
"list": list,
"dict": dict,
"set": set,
"frozenset": frozenset,
# "frozenset": frozenset,
"tuple": tuple,
}
)
@@ -285,19 +803,24 @@ class AnnotationWalker:
self.__ann = eval(ann, {"__builtins__": {}}, self._allowed_annotations)
def run(self) -> TriggerResult:
    """Walk the stored annotation once, driving every registered trigger.

    Order: `init_trigger()` on each trigger, one full walk, then
    `end_trigger(root_ctx)` on each trigger.

    Returns:
        Whatever the walk of the root annotation returned.

    Raises:
        Anything raised by a trigger or by the walk itself; in that case the
        remaining steps are not executed.
    """
    for t in self._triggers:
        t.init_trigger()
    # Reset before walking; _walk() stores the root context here.
    self.root_ctx = None
    result = self._walk(self.__ann, None)
    for t in self._triggers:
        t.end_trigger(self.root_ctx)
    return result
# --- Helpers ---
def _new_ctx(self, origin, args, layer, parent):
return AnnotationWalkerCtx(origin, args, layer, parent, self._allowed_types, self._allowed_annotations)
return AnnotationWalkerCtx(
origin, args, layer, parent, self._allowed_types, self._allowed_annotations
)
def _apply_triggers(self, method: str, ctx: AnnotationWalkerCtx) -> TriggerResult:
final = TriggerResult.passthrough()
for trig in self._triggers:
res = getattr(trig, method)(ctx)
for t in self._triggers:
res = getattr(t, method)(ctx)
if not res:
continue
if res.restart_with is not None:
@@ -311,33 +834,64 @@ class AnnotationWalker:
)
return final
def _apply_exits(self, ctx: AnnotationWalkerCtx) -> Any | None:
"""
Run the generic `process_exit` hook of every trigger, in trigger order.
Only `replace_with` is honored at exit; the last non-None `replace_with` wins.
Returns that replacement, or None when no trigger supplied one.
"""
final = None
for t in self._triggers:
res = t.process_exit(ctx)
# Falsy results and results without a replacement are ignored.
if res and (res.replace_with is not None):
final = res.replace_with
return final
def _handle_with_triggers(
self,
trigger_name: str,
ctx: AnnotationWalkerCtx,
args_handler: Callable[[AnnotationWalkerCtx], Any] | None = None,
) -> Any:
"""Generic handler: run triggers, maybe recurse into args with a custom handler."""
# ENTER
res = self._apply_triggers(trigger_name, ctx)
if res.restart_with is not None:
return self._walk(res.restart_with, ctx.parent)
if res.replace_with is not None:
return res.replace_with
exit_val = self._apply_exits(ctx)
return exit_val if exit_val is not None else res.replace_with
node_value = None
if not res.skip_children:
if args_handler:
return args_handler(ctx)
return tuple(self._walk(a, ctx) for a in ctx.args)
return None
node_value = args_handler(ctx)
else:
# DEFAULT: descend once per schema arg; mark as positional arg
node_value = tuple(
self._walk_child(a, ctx, arg_index=i, role="arg", token=i)
for i, a in enumerate(ctx.args)
)
# EXIT
exit_val = self._apply_exits(ctx)
return exit_val if exit_val is not None else node_value
def _walk_args_tuple(self, ctx: AnnotationWalkerCtx):
# special Ellipsis case for Tuple
# Tuple[T, ...] (variadic): one schema child, mark role='elem'
if len(ctx.args) == 2 and ctx.args[1] is Ellipsis:
return (self._walk(ctx.args[0], ctx), Ellipsis)
return tuple(self._walk(a, ctx) for a in ctx.args)
return (
self._walk_child(ctx.args[0], ctx, arg_index=0, role="elem", token=None),
Ellipsis,
)
# Fixed tuple: each positional arg gets role='arg'
return tuple(
self._walk_child(a, ctx, arg_index=i, role="arg", token=i)
for i, a in enumerate(ctx.args)
)
# --- Dispatcher ---
def _walk(self, type_: Any, parent_ctx: Optional[AnnotationWalkerCtx]) -> Any:
# For logs only: show the calling layer
print(f"[{parent_ctx.layer if parent_ctx else 0}] walking through: {type_}")
origin = get_origin(type_) or type_
@@ -349,31 +903,123 @@ class AnnotationWalker:
raise RuntimeError("Annotation must be using type(s), not instances")
args = get_args(type_)
layer = 0 if parent_ctx is None else parent_ctx.layer + 1
ctx = self._new_ctx(origin, args, layer, parent_ctx)
# IMPORTANT:
# If caller (_walk_child) already constructed a child ctx with edge metadata,
# reuse that object as *the* ctx for this node instead of allocating a new one.
if (
isinstance(parent_ctx, AnnotationWalkerCtx)
and parent_ctx.origin is origin
and parent_ctx.args == args
):
ctx = parent_ctx
else:
# Root or internal calls that didn't prebuild the ctx
layer = 0 if parent_ctx is None else parent_ctx.layer + 1
ctx = self._new_ctx(origin, args, layer, parent_ctx)
# Remember root ctx for end_trigger()
if ctx.parent is None:
self.root_ctx = ctx
print(origin)
match origin:
case typing.Annotated:
# inner type gets role='annotated'
return self._handle_with_triggers(
"process_annotated", ctx, args_handler=lambda c: self._walk(c.args[0], c) if c.args else None
"process_annotated",
ctx,
args_handler=lambda c: (
self._walk_child(c.args[0], c, arg_index=0, role="annotated", token=None)
if c.args
else None
),
)
case types.UnionType:
return self._handle_with_triggers("process_union", ctx)
# branches get role='branch' and token=branch index
return self._handle_with_triggers(
"process_union",
ctx,
args_handler=lambda c: tuple(
self._walk_child(a, c, arg_index=i, role="branch", token=i)
for i, a in enumerate(c.args)
),
)
case _ if issubclass(origin, dict):
return self._handle_with_triggers("process_dict", ctx)
# arg0=key (role='key'), arg1=value (role='val')
return self._handle_with_triggers(
"process_dict",
ctx,
args_handler=lambda c: (
self._walk_child(c.args[0], c, arg_index=0, role="key", token=None),
self._walk_child(c.args[1], c, arg_index=1, role="val", token=None),
),
)
case _ if issubclass(origin, tuple):
return self._handle_with_triggers("process_tuple", ctx, self._walk_args_tuple)
case _ if issubclass(origin, list):
return self._handle_with_triggers("process_list", ctx)
# single child T with role='elem'
return self._handle_with_triggers(
"process_list",
ctx,
args_handler=lambda c: (
self._walk_child(c.args[0], c, arg_index=0, role="elem", token=None)
if c.args
else None
),
)
case _ if issubclass(origin, set):
return self._handle_with_triggers("process_set", ctx)
# single child T with role='elem'
return self._handle_with_triggers(
"process_set",
ctx,
args_handler=lambda c: (
self._walk_child(c.args[0], c, arg_index=0, role="elem", token=None)
if c.args
else None
),
)
case _ if origin in self._allowed_types:
return self._handle_with_triggers("process_allowed", ctx)
case _:
res = self._apply_triggers("process_unknown", ctx)
if res.restart_with is not None:
return self._walk(res.restart_with, ctx.parent)
if res.replace_with is not None:
return res.replace_with
raise UnsupportedFieldType(f"Not supported Field: {ctx.origin}, " f"Supported list: {self._allowed_types}")
raise UnsupportedFieldType(
f"Not supported Field: {ctx.origin}, Supported list: {self._allowed_types}"
)
def _walk_child(
self,
type_expr: Any,
parent_ctx: AnnotationWalkerCtx,
*,
arg_index: int,
role: str | None,
token: Any | None,
) -> Any:
# Descend into one schema argument of `parent_ctx`: build the child context,
# stamp it with edge metadata (role / token / arg_index), then walk it.
# Returns whatever _walk() returns for `type_expr`.
# Raises RuntimeError when the normalised origin is not a type.

# Normalise the origin: bare None becomes NoneType, typing.Union becomes UnionType.
origin = get_origin(type_expr) or type_expr
if origin is None:
origin = NoneType
if origin is Union:
origin = UnionType
if not isinstance(origin, type):
raise RuntimeError("Annotation must be using type(s), not instances")
args = get_args(type_expr)
child = self._new_ctx(origin, args, parent_ctx.layer + 1, parent_ctx)
# stamp routing metadata for triggers
child._set_edge(role=role, token=token, arg_index=arg_index)
# IMPORTANT: the prebuilt CHILD ctx is passed in the parent_ctx slot of _walk(),
# which is expected to recognise it (same origin and args) and reuse it as the
# node's own ctx, keeping the edge metadata and the incremented layer.
return self._walk(type_expr, child)

View File

@@ -22,11 +22,18 @@ testdir_path = Path(__file__).parent.resolve()
chdir(testdir_path.parent.resolve())
class ElementTest(unittest.TestCase):
class AnnotationsWalkerTest(unittest.TestCase):
def setUp(self):
print("\n->", unittest.TestCase.id(self))
def test_element_simple(self):
def test_validate(self):
ann = dict[int, list[int]] | dict[int, list[int | str]]
val = {1: [2], 2: ["a", [1]]}
res = dm.tools.AnnotationWalker(ann, (dm.tools.HorizontalValidationTrigger(val),))
res.run()
def test_simple(self):
print(isinstance(None, type(None)))
print("\n== From OBJs ==")
res = dm.tools.AnnotationWalker(Annotated[Optional[dict[int, list[str]]], "comment"], (dm.tools.LAMSchemaValidation(),))