|
|
|
|
@@ -1,7 +1,22 @@
|
|
|
|
|
"""library's internal tools"""
|
|
|
|
|
|
|
|
|
|
from collections import ChainMap
|
|
|
|
|
from typing import Any, Annotated, get_origin, get_args, Union, Self, Optional, List, Dict, Tuple, Set, FrozenSet, Mapping, Callable
|
|
|
|
|
from typing import (
|
|
|
|
|
Any,
|
|
|
|
|
Annotated,
|
|
|
|
|
get_origin,
|
|
|
|
|
get_args,
|
|
|
|
|
Union,
|
|
|
|
|
Self,
|
|
|
|
|
Optional,
|
|
|
|
|
List,
|
|
|
|
|
Dict,
|
|
|
|
|
Tuple,
|
|
|
|
|
Set,
|
|
|
|
|
FrozenSet,
|
|
|
|
|
Mapping,
|
|
|
|
|
Callable,
|
|
|
|
|
)
|
|
|
|
|
import typing
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from types import UnionType, NoneType
|
|
|
|
|
@@ -79,6 +94,13 @@ class AnnotationWalkerCtx:
|
|
|
|
|
self.__allowed_annotations: dict[str, Any] = allowed_annotations
|
|
|
|
|
self.__ext: dict[Any, ChainMap] = {} # per-trigger namespaces (lazy)
|
|
|
|
|
|
|
|
|
|
# NEW: edge routing metadata (walker sets; triggers read)
|
|
|
|
|
self.__edge_role: str | None = (
|
|
|
|
|
None # 'elem' | 'key' | 'val' | 'branch' | 'annotated' | 'arg'
|
|
|
|
|
)
|
|
|
|
|
self.__edge_token: Any | None = None # index/key/branch-id/etc.
|
|
|
|
|
self.__arg_index: int | None = None # which schema arg of the parent this child is
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def origin(self) -> Any:
|
|
|
|
|
return self.__origin
|
|
|
|
|
@@ -99,6 +121,19 @@ class AnnotationWalkerCtx:
|
|
|
|
|
def allowed_annotations(self) -> Mapping[str, Any]:
|
|
|
|
|
return self.__allowed_annotations
|
|
|
|
|
|
|
|
|
|
# NEW: read-only edge metadata for routing inside triggers
|
|
|
|
|
@property
|
|
|
|
|
def edge_role(self) -> str | None:
|
|
|
|
|
return self.__edge_role
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def edge_token(self) -> Any | None:
|
|
|
|
|
return self.__edge_token
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def arg_index(self) -> int | None:
|
|
|
|
|
return self.__arg_index
|
|
|
|
|
|
|
|
|
|
def ns(self, owner: Any) -> ChainMap:
|
|
|
|
|
"""
|
|
|
|
|
A per-trigger overlay namespace that inherits from parent ctx.
|
|
|
|
|
@@ -107,11 +142,32 @@ class AnnotationWalkerCtx:
|
|
|
|
|
"""
|
|
|
|
|
if owner in self.__ext:
|
|
|
|
|
return self.__ext[owner]
|
|
|
|
|
parent_map = self.__parent.__ext.get(owner) if (self.__parent and hasattr(self.__parent, "_AnnotationWalkerCtx__ext")) else {}
|
|
|
|
|
cm = ChainMap({}, parent_map if isinstance(parent_map, ChainMap) else dict(parent_map))
|
|
|
|
|
|
|
|
|
|
# Determine the parent chain for this owner
|
|
|
|
|
parent_chain = ()
|
|
|
|
|
if self.__parent is not None and hasattr(self.__parent, "_AnnotationWalkerCtx__ext"):
|
|
|
|
|
parent_map = self.__parent._AnnotationWalkerCtx__ext.get(
|
|
|
|
|
owner
|
|
|
|
|
) # access private attr intentionally
|
|
|
|
|
if isinstance(parent_map, ChainMap):
|
|
|
|
|
# IMPORTANT: preserve the whole chain (not the ChainMap object as a single mapping)
|
|
|
|
|
parent_chain = tuple(parent_map.maps)
|
|
|
|
|
elif isinstance(parent_map, dict):
|
|
|
|
|
parent_chain = (parent_map,)
|
|
|
|
|
else:
|
|
|
|
|
parent_chain = ()
|
|
|
|
|
|
|
|
|
|
# Build a new ChainMap whose first map is this node's local overlay
|
|
|
|
|
cm = ChainMap({}, *parent_chain)
|
|
|
|
|
self.__ext[owner] = cm
|
|
|
|
|
return cm
|
|
|
|
|
|
|
|
|
|
# INTERNAL (walker only): set edge metadata on a child ctx
|
|
|
|
|
def _set_edge(self, *, role: str | None, token: Any | None, arg_index: int | None) -> None:
|
|
|
|
|
self.__edge_role = role
|
|
|
|
|
self.__edge_token = token
|
|
|
|
|
self.__arg_index = arg_index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
|
|
class TriggerResult:
|
|
|
|
|
@@ -169,6 +225,12 @@ class AnnotationTrigger:
|
|
|
|
|
def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]) -> None:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LAMSchemaValidation(AnnotationTrigger):
|
|
|
|
|
def init_trigger(self) -> None:
|
|
|
|
|
@@ -188,13 +250,17 @@ class LAMSchemaValidation(AnnotationTrigger):
|
|
|
|
|
print("process_union")
|
|
|
|
|
print(ctx.args)
|
|
|
|
|
if (len(ctx.args) != 2) or (type(None) not in list(ctx.args)):
|
|
|
|
|
raise UnsupportedFieldType("Union[] is only supported to implement Optional[] (takes 2 parameters, including None)")
|
|
|
|
|
raise UnsupportedFieldType(
|
|
|
|
|
"Union[] is only supported to implement Optional[] (takes 2 parameters, including None)"
|
|
|
|
|
)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def process_dict(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
|
|
|
|
|
print("process_dict")
|
|
|
|
|
if len(ctx.args) != 2:
|
|
|
|
|
raise IncompletelyAnnotatedField(f"Dict Annotation requires 2 inner definitions: {ctx.origin}")
|
|
|
|
|
raise IncompletelyAnnotatedField(
|
|
|
|
|
f"Dict Annotation requires 2 inner definitions: {ctx.origin}"
|
|
|
|
|
)
|
|
|
|
|
if not ctx.args[0] in ctx.allowed_types:
|
|
|
|
|
raise IncompletelyAnnotatedField(f"Dict Key must be simple builtin: {ctx.origin}")
|
|
|
|
|
return None
|
|
|
|
|
@@ -220,10 +286,462 @@ class LAMSchemaValidation(AnnotationTrigger):
|
|
|
|
|
def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
|
|
|
|
|
print("process_allowed")
|
|
|
|
|
if ctx.origin is type(None) or ctx.origin is None:
|
|
|
|
|
if ctx.parent is None or not (ctx.parent.origin is Union or ctx.parent.origin is UnionType):
|
|
|
|
|
raise IncompletelyAnnotatedField(f"None is only accepted with Union, to implement Optional[]")
|
|
|
|
|
if ctx.parent is None or not (
|
|
|
|
|
ctx.parent.origin is Union or ctx.parent.origin is UnionType
|
|
|
|
|
):
|
|
|
|
|
raise IncompletelyAnnotatedField(
|
|
|
|
|
f"None is only accepted with Union, to implement Optional[]"
|
|
|
|
|
)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
|
|
|
|
|
print(f"process_exit: {ctx.origin}")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def end_trigger(self, ctx: AnnotationWalkerCtx):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataValidation(AnnotationTrigger):
    """Trigger skeleton that carries a concrete value alongside the walk.

    Every hook only makes sure this trigger's per-node namespace exists
    (seeding ``"value"`` and ``"path"`` the first time one is requested);
    no validation is performed by the hooks themselves.
    """

    def __init__(self, value: Any) -> None:
        """Remember *value* as the root object to seed into the namespace."""
        self._root = value
        # Initialised here as well as in init_trigger() so that _bag() cannot
        # hit a missing attribute when a hook runs before init_trigger().
        self._seeded = False

    def init_trigger(self) -> None:
        """Reset the seeding flag so the next ``_bag`` call seeds again."""
        self._seeded = False

    def _bag(self, ctx: AnnotationWalkerCtx):
        """Return ``ctx.ns(self)``, seeding it on the first call after a reset.

        The first namespace requested receives ``"value"`` (the root value)
        and ``"path"`` (an empty tuple).
        """
        bag = ctx.ns(self)
        if not self._seeded:
            bag["value"] = self._root
            bag["path"] = ()
            self._seeded = True
        return bag

    def process_annotated(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        """Touch the namespace for an ``Annotated`` node; no result."""
        self._bag(ctx)
        return None

    def process_union(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        """Touch the namespace for a union node; no result."""
        self._bag(ctx)
        return None

    def process_dict(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        """Touch the namespace for a dict node; no result."""
        self._bag(ctx)
        return None

    def process_tuple(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        """Touch the namespace for a tuple node; no result."""
        self._bag(ctx)
        return None

    def process_list(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        """Touch the namespace for a list node; no result."""
        self._bag(ctx)
        return None

    def process_set(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        """Touch the namespace for a set node; no result."""
        self._bag(ctx)
        return None

    def process_unknown(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        """Touch the namespace for an unrecognised node; no result."""
        self._bag(ctx)
        return None

    def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        """Touch the namespace for an allowed leaf type; no result."""
        self._bag(ctx)
        return None

    def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        """Touch the namespace when a node is left; no result."""
        self._bag(ctx)
        return None

    def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]) -> None:
        """Nothing to finalise; accepts ``None`` like the base-class hook."""
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SchemaValidationError(TypeError):
    """Validation failure tied to a location inside the validated value.

    The exception text is ``"<location>: <msg>"``.  ``<location>`` is built
    from *path*: ``int`` components render as ``[n]``, a non-int component in
    first position renders via ``str()``, later non-int components are
    prefixed with a dot, and an empty path renders as ``<root>``.

    Attributes:
        path: The path tuple exactly as passed in.
        msg: The bare message, without the location prefix.
    """

    def __init__(self, path: tuple[Any, ...], msg: str):
        """Build the located message and keep *path* / *msg* for callers."""
        dotted = "".join(
            f"[{part}]" if isinstance(part, int) else (f".{part}" if pos else str(part))
            for pos, part in enumerate(path)
        ) or "<root>"
        super().__init__(f"{dotted}: {msg}")
        self.path = path
        self.msg = msg

    def __reduce__(self):
        """Make the exception survive ``pickle`` / ``copy``.

        ``BaseException`` rebuilds instances as ``cls(*self.args)``, and
        ``args`` holds only the single formatted string, which does not match
        the two-argument ``__init__``.  Rebuild from the original arguments
        and carry the instance dict along so extra attributes are kept.
        """
        return type(self), (self.path, self.msg), dict(vars(self))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HorizontalValidationTrigger(AnnotationTrigger):
    """Validate a concrete value against the annotation while it is walked.

    Each visited node works on a list of *candidates*: dicts with ``"id"``,
    ``"value"``, ``"path"`` and ``"union_roots"``, stored in this trigger's
    per-node namespace (``ctx.ns(self)``).  Container hooks type-check their
    candidates and publish the contained values as ``pending_*`` lists; leaf
    hooks check candidates with ``isinstance``.  Union nodes are decided on
    exit: a candidate passes when at least one branch recorded no error for
    it.  Errors reaching the root namespace are raised together from
    ``end_trigger``.
    """

    # Keys marking "this node sits inside a branch of an enclosing union".
    _UNION_FLAG_KEYS = ("__union_state_ref", "__union_branch_id", "__union_scope_id")
    # Edge role of a child -> key under which its parent published candidates.
    _ROLE_TO_PENDING = {"elem": "pending_elem", "key": "pending_key", "val": "pending_val"}

    def __init__(
        self,
        value: Any,
        *,
        strict_bool: bool = False,
        collect_all: bool = True,
        union_summary_max_per_branch: int = 3,
        show_branch_types: bool = True,
    ):
        """Configure the trigger.

        Args:
            value: Root value to validate.
            strict_bool: When true, a ``bool`` leaf only accepts real bools.
            collect_all: When false, the first error recorded outside any
                union branch is raised immediately.
            union_summary_max_per_branch: Maximum number of errors listed
                per branch in a union failure summary.
            show_branch_types: Stored only; nothing in this class reads it.
        """
        self._root_value = value
        self._strict_bool = strict_bool
        self._collect_all = collect_all
        self._union_summary_max = union_summary_max_per_branch
        self._show_branch_types = show_branch_types
        self._seeded = False
        self._id_counter = 0

    # ---------- utilities ----------

    def _mk_cand(self, value: Any, path: tuple[Any, ...]):
        """Create a candidate dict with a fresh, monotonically increasing id."""
        cid = self._id_counter
        self._id_counter += 1
        return {"id": cid, "value": value, "path": path}

    def _spawn_from(self, bag, parent_cand: dict, value: Any, path: tuple[Any, ...]):
        """Create a candidate for a value contained in *parent_cand*.

        The child gets a copy of the parent's ``union_roots``.  Inside a union
        branch the current scope is stamped too, pointing at the parent's id
        when the parent carries no root for that scope yet.
        """
        child = self._mk_cand(value, path)
        roots = dict(parent_cand.get("union_roots", {}))
        child["union_roots"] = roots

        if "__union_branch_id" in bag and "__union_state_ref" in bag:
            cur_scope = bag.get("__union_scope_id")
            if cur_scope is not None:
                # `roots` already holds the parent's entries, so the parent's
                # id is the only fallback that can ever be needed.
                roots.setdefault(cur_scope, parent_cand["id"])
        return child

    @staticmethod
    def _branch_state(union_bag, bid):
        """Return the mutable state dict of branch *bid* of one union node.

        The table is kept in the union node's OWN map (``maps[0]``).  Going
        through the ChainMap would find an enclosing union's table first and
        mix the bookkeeping of nested unions, whose branch ids both start at 0.
        """
        own = union_bag.maps[0] if hasattr(union_bag, "maps") else union_bag
        branches = own.setdefault("union_branches_state", {})
        return branches.setdefault(bid, {"failed_ids": set(), "errors": []})

    @staticmethod
    def _union_root_id(bag, cand: dict):
        """Id of the union-level candidate that *cand* descends from.

        Resolved through ``cand["union_roots"]`` for the nearest union scope,
        which is what ``_spawn_from`` and the branch-entry stamping maintain.
        Falls back to the legacy ``"union_cid"`` key and then to the
        candidate's own id.
        """
        scope = bag.get("__union_scope_id")
        roots = cand.get("union_roots") or {}
        if scope is not None and scope in roots:
            return roots[scope]
        return cand.get("union_cid", cand["id"])

    def _route_candidates(self, ctx: AnnotationWalkerCtx, local: dict, parent_bag) -> None:
        """Give this node its ``candidates`` list (once).

        The root node is seeded with the root value and an empty ``errors``
        list.  Any other node copies the list its parent published for this
        edge, read from the parent's own map only.

        NOTE(review): the role -> key table is inferred from the ``pending_*``
        keys the container hooks in this class write and from the edge roles
        the walker stamps ('elem', 'key', 'val', 'arg', 'branch',
        'annotated'); confirm against the intended routing.
        """
        if "candidates" in local:
            return
        if parent_bag is None:
            local.setdefault("errors", [])
            local["candidates"] = [self._mk_cand(self._root_value, ())]
            self._seeded = True
            return

        role = ctx.edge_role
        if role == "arg":
            source = f"pending_arg_{ctx.arg_index}"
        else:
            # 'branch', 'annotated' and unlabelled edges look at the same
            # values as their parent.
            source = self._ROLE_TO_PENDING.get(role, "candidates")
        local["candidates"] = list(parent_bag.maps[0].get(source, ()))

    def _bag(self, ctx: AnnotationWalkerCtx):
        """Return this trigger's namespace for *ctx*, prepared for validation.

        Steps: route candidates into the node, copy the enclosing union flags
        from the parent, and, when the node is a direct union branch, bind the
        branch-local error bucket and stamp every candidate with the union
        scope.
        """
        bag = ctx.ns(self)
        local = bag.maps[0]
        parent_bag = ctx.parent.ns(self) if ctx.parent is not None else None

        self._route_candidates(ctx, local, parent_bag)

        # Inherit enclosing-union flags so every descendant sees them locally.
        if parent_bag is not None:
            for key in self._UNION_FLAG_KEYS:
                if key in parent_bag and key not in local:
                    local[key] = parent_bag[key]

        # Immediate child of a union node: this node IS a branch.
        if (
            parent_bag is not None
            and ctx.parent.origin is UnionType
            and ctx.edge_role == "branch"
        ):
            bid = ctx.edge_token
            bstate = self._branch_state(parent_bag, bid)
            local["errors"] = bstate["errors"]
            local["__union_branch_id"] = bid
            local["__union_state_ref"] = parent_bag
            cur_scope = parent_bag.get("__union_scope_id")
            local["__union_scope_id"] = cur_scope

            # Stamp union_roots on copies so sibling branches stay independent.
            stamped = []
            for cand in local.get("candidates", []):
                marked = dict(cand)
                roots = dict(marked.get("union_roots", {}))
                if cur_scope is not None:
                    roots.setdefault(cur_scope, marked["id"])
                marked["union_roots"] = roots
                stamped.append(marked)
            local["candidates"] = stamped

        return bag

    def _mark_branch_fail(self, bag, cand: dict):
        """Mark *cand*'s union-root candidate as failed in the current branch.

        Does nothing when *bag* is not inside a union branch.
        """
        if "__union_branch_id" in bag and "__union_state_ref" in bag:
            bstate = self._branch_state(bag["__union_state_ref"], bag["__union_branch_id"])
            bstate["failed_ids"].add(self._union_root_id(bag, cand))

    def _err(self, bag, cand, msg: str):
        """Record a validation error for *cand*.

        Inside a union branch the error goes to that branch's bucket, tagged
        with the union-root candidate id and the union scope so the union can
        attribute it on exit; nothing is raised.  Outside any union branch it
        is appended to ``bag["errors"]`` and raised immediately when
        ``collect_all`` is false.

        Raises:
            SchemaValidationError: outside a union branch with
                ``collect_all=False``.
        """
        e = SchemaValidationError(cand["path"], msg)

        if "__union_branch_id" in bag and "__union_state_ref" in bag:
            bstate = self._branch_state(bag["__union_state_ref"], bag["__union_branch_id"])
            root_id = self._union_root_id(bag, cand)
            bstate["failed_ids"].add(root_id)
            bstate["errors"].append(
                {"cid": root_id, "scope": bag.get("__union_scope_id"), "err": e}
            )
            return

        bag["errors"].append(e)
        if not self._collect_all:
            raise e

    def _check_leaf(self, bag, cand: dict, T: type):
        """Check one candidate against leaf type *T* and record a mismatch."""
        v = cand["value"]
        if T is type(None):
            if v is not None:
                self._err(bag, cand, "expected None")
            return
        if self._strict_bool and T is bool and type(v) is not bool:
            self._err(bag, cand, f"expected bool, got {type(v).__name__}")
            return
        if v is None or not isinstance(v, T):
            got = "None" if v is None else type(v).__name__
            self._err(bag, cand, f"expected {T.__name__}, got {got}")

    # ---------- entry hooks (seed pending; validate leaves; no aggregation here) ----------

    def process_allowed(self, ctx: AnnotationWalkerCtx):
        """Leaf hook: check every candidate against ``ctx.origin``."""
        bag = self._bag(ctx)
        leaf_type = ctx.origin
        for cand in bag.get("candidates", []):
            self._check_leaf(bag, cand, leaf_type)

    def process_list(self, ctx: AnnotationWalkerCtx):
        """List hook: require ``list`` values and publish their elements."""
        bag = self._bag(ctx)
        cands = bag.get("candidates", [])
        if len(ctx.args) != 1:
            for cand in cands:
                self._err(bag, cand, "List[T] requires 1 argument")
            return
        pending = []
        for cand in cands:
            v = cand["value"]
            if not isinstance(v, list):
                self._err(bag, cand, f"expected list, got {type(v).__name__}")
                continue
            base = cand["path"]
            for i, elem in enumerate(v):
                pending.append(self._spawn_from(bag, cand, elem, base + (i,)))
        bag["pending_elem"] = pending

    def process_tuple(self, ctx: AnnotationWalkerCtx):
        """Tuple hook.

        ``tuple[T, ...]`` publishes all elements as ``pending_elem``.  A
        fixed-arity tuple checks the length and publishes position ``i`` as
        ``pending_arg_{i}``.
        """
        bag = self._bag(ctx)
        cands = bag.get("candidates", [])
        variadic = len(ctx.args) == 2 and ctx.args[1] is Ellipsis

        if variadic:
            pending = []
            for cand in cands:
                v = cand["value"]
                if not isinstance(v, tuple):
                    self._err(bag, cand, f"expected tuple, got {type(v).__name__}")
                    continue
                base = cand["path"]
                for i, elem in enumerate(v):
                    pending.append(self._spawn_from(bag, cand, elem, base + (i,)))
            bag["pending_elem"] = pending
            return

        arity = len(ctx.args)
        slots: dict[int, list] = {}
        for cand in cands:
            v = cand["value"]
            if not isinstance(v, tuple):
                self._err(bag, cand, f"expected tuple, got {type(v).__name__}")
                continue
            if len(v) != arity:
                self._err(bag, cand, f"expected tuple len {arity}, got {len(v)}")
                continue
            base = cand["path"]
            for i, elem in enumerate(v):
                slots.setdefault(i, []).append(self._spawn_from(bag, cand, elem, base + (i,)))
        for i, group in slots.items():
            bag[f"pending_arg_{i}"] = group

    def process_set(self, ctx: AnnotationWalkerCtx):
        """Set hook: require ``set``/``frozenset`` and publish the members."""
        bag = self._bag(ctx)
        cands = bag.get("candidates", [])
        if len(ctx.args) != 1:
            for cand in cands:
                self._err(bag, cand, "Set[T]/FrozenSet[T] requires 1 argument")
            return
        pending = []
        for cand in cands:
            v = cand["value"]
            if not isinstance(v, (set, frozenset)):
                self._err(bag, cand, f"expected set/frozenset, got {type(v).__name__}")
                continue
            base = cand["path"]
            for member in v:
                pending.append(self._spawn_from(bag, cand, member, base + ("<elem>",)))
        bag["pending_elem"] = pending

    def process_dict(self, ctx: AnnotationWalkerCtx):
        """Dict hook: require ``dict`` and publish keys and values separately."""
        bag = self._bag(ctx)
        cands = bag.get("candidates", [])
        if len(ctx.args) != 2:
            for cand in cands:
                self._err(bag, cand, "Dict[K,V] requires 2 arguments")
            return
        pkeys, pvals = [], []
        for cand in cands:
            v = cand["value"]
            if not isinstance(v, dict):
                self._err(bag, cand, f"expected dict, got {type(v).__name__}")
                continue
            base = cand["path"]
            for k, val in v.items():
                pkeys.append(self._spawn_from(bag, cand, k, base + ("<key>",)))
                pvals.append(self._spawn_from(bag, cand, val, base + (k,)))
        bag["pending_key"] = pkeys
        bag["pending_val"] = pvals

    def process_annotated(self, ctx: AnnotationWalkerCtx):
        """Annotated hook: no checks here; the inner type validates."""
        self._bag(ctx)

    def process_union(self, ctx: AnnotationWalkerCtx):
        """Union hook: open a new union scope for this node.

        All bookkeeping is written to the node's own map so that a union
        nested inside another union's branch never reuses the outer one's
        branch table or outer-scope marker.
        """
        bag = self._bag(ctx)
        local = bag.maps[0]
        cands = bag.get("candidates", [])

        # Remember the enclosing union scope (if any) before replacing it.
        if "__outer_union_scope_id" not in local and "__union_scope_id" in bag:
            local["__outer_union_scope_id"] = bag["__union_scope_id"]

        local["union_candidate_ids"] = [c["id"] for c in cands]
        local.setdefault("union_branches_state", {})  # bid -> {"failed_ids", "errors"}
        local["union_branch_labels"] = [self._pretty_type(a) for a in ctx.args]
        local["__union_scope_id"] = id(ctx)

    # ---------- exit (aggregation only) ----------

    def _scoped_branch_errors(self, branches: dict, scope_id) -> dict:
        """Group branch errors as ``{bid: {candidate id: [errors]}}``.

        Tagged entries (dicts) are kept only when their scope is *scope_id*.
        Untagged ``SchemaValidationError`` entries are attributed to every id
        in the branch's ``failed_ids``.  Duplicates (same path and text) are
        dropped per candidate, keeping first occurrences.
        """
        per_branch: dict[int, dict[int, list[SchemaValidationError]]] = {}
        for bid, state in branches.items():
            by_cand: dict[int, list[SchemaValidationError]] = {}
            for item in state.get("errors", []):
                if isinstance(item, dict):
                    if item.get("scope") != scope_id:
                        continue
                    cid = item.get("cid")
                    err = item.get("err")
                    if cid is not None and isinstance(err, SchemaValidationError):
                        by_cand.setdefault(cid, []).append(err)
                elif isinstance(item, SchemaValidationError):
                    for cid in state.get("failed_ids", set()):
                        by_cand.setdefault(cid, []).append(item)

            for cid, errs in by_cand.items():
                seen = set()
                unique = []
                for e in errs:
                    key = (e.path, str(e))
                    if key not in seen:
                        seen.add(key)
                        unique.append(e)
                by_cand[cid] = unique
            per_branch[bid] = by_cand
        return per_branch

    def _union_summary(self, cid, branches: dict, labels: list, per_branch: dict) -> str:
        """Render the "no union branch matched" text for candidate *cid*."""
        limit = self._union_summary_max
        lines = ["no union branch matched; tried:"]
        for bid in branches:
            label = labels[bid] if bid < len(labels) else f"branch {bid}"
            these = per_branch.get(bid, {}).get(cid, [])
            if not these:
                lines.append(f" - {label}: mismatch")
                continue
            lines.append(f" - {label}: {len(these)} issue(s)")
            for e in these[:limit]:
                lines.append(f" - {e}")
            if len(these) > limit:
                lines.append(f" - (+{len(these) - limit} more)")
        return "\n".join(lines)

    def process_exit(self, ctx: AnnotationWalkerCtx):
        """Decide union nodes on exit; every other node is ignored.

        A candidate passes when any visited branch has no error attributed to
        it in this union's scope.  For a failing candidate a summary error is
        appended to the nearest ``errors`` list, and, when this union itself
        sits in a branch of an enclosing union, the failure is also recorded
        in that branch against the enclosing union's root candidate.
        """
        if ctx.origin is not UnionType:
            return None

        bag = ctx.ns(self)
        own = bag.maps[0]  # this union's bookkeeping lives in its own map only
        root_errors = bag.get("errors", [])
        branches = own.get("union_branches_state", {})
        cand_ids = own.get("union_candidate_ids", [])
        labels = own.get("union_branch_labels", [])
        scope_id = own.get("__union_scope_id")
        outer_scope_id = own.get("__outer_union_scope_id")

        per_branch = self._scoped_branch_errors(branches, scope_id)

        cands_by_id: dict = {}
        for cand in bag.get("candidates", []):
            cands_by_id.setdefault(cand["id"], cand)

        nested = "__union_state_ref" in bag and "__union_branch_id" in bag

        for cid in cand_ids:
            if any(not per_branch.get(bid, {}).get(cid) for bid in branches):
                continue  # at least one branch accepted this candidate

            cand = cands_by_id.get(cid)

            if nested:
                uref = bag["__union_state_ref"]  # enclosing union node's bag
                up_state = self._branch_state(uref, bag["__union_branch_id"])
                if cand is not None and outer_scope_id is not None:
                    outer_root = cand.get("union_roots", {}).get(outer_scope_id, cand["id"])
                else:
                    outer_root = cid
                up_state["failed_ids"].add(outer_root)
                up_state["errors"].append(
                    {
                        "cid": outer_root,
                        "scope": uref.get("__union_scope_id"),
                        "err": SchemaValidationError((), "mismatch"),
                    }
                )

            path = cand["path"] if cand is not None else ()
            summary = self._union_summary(cid, branches, labels, per_branch)
            root_errors.append(SchemaValidationError(path, summary))

        return None

    def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]):
        """Raise every collected error together.

        Raises:
            ExceptionGroup: when the namespace of *ctx* holds any errors.
        """
        if ctx is None:
            return
        errors = ctx.ns(self).get("errors", [])
        if errors:
            raise ExceptionGroup("schema validation failed", errors)

    # --- pretty type for messages (best-effort) ---
    def _pretty_type(self, t: Any) -> str:
        """Render annotation *t* compactly; falls back to ``repr(t)`` on error."""
        origin = get_origin(t) or t
        args = get_args(t)
        try:
            if origin is UnionType or origin is Union:
                return " | ".join(self._pretty_type(a) for a in args)
            if origin in (list, tuple, set, frozenset, dict, Annotated):
                if origin is dict and len(args) == 2:
                    return f"dict[{self._pretty_type(args[0])}, {self._pretty_type(args[1])}]"
                if origin in (list, set, frozenset) and len(args) == 1:
                    name = "list" if origin is list else ("set" if origin is set else "frozenset")
                    return f"{name}[{self._pretty_type(args[0])}]"
                if origin is tuple:
                    if len(args) == 2 and args[1] is Ellipsis:
                        return f"tuple[{self._pretty_type(args[0])}, ...]"
                    return f"tuple[{', '.join(self._pretty_type(a) for a in args)}]"
                if origin is Annotated and args:
                    return f"Annotated[{self._pretty_type(args[0])}, ...]"
            if isinstance(origin, type):
                return origin.__name__
            return str(t)
        except Exception:
            # Labels are cosmetic; never let formatting break validation.
            return repr(t)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AnnotationWalker:
|
|
|
|
|
DEFAULT_ALLOWED_TYPES = frozenset({str, int, float, complex, bool, bytes, NoneType})
|
|
|
|
|
@@ -235,7 +753,7 @@ class AnnotationWalker:
|
|
|
|
|
"Dict": Dict,
|
|
|
|
|
"Tuple": Tuple,
|
|
|
|
|
"Set": Set,
|
|
|
|
|
"FrozenSet": FrozenSet,
|
|
|
|
|
# "FrozenSet": FrozenSet,
|
|
|
|
|
"Annotated": Annotated,
|
|
|
|
|
# builtins:
|
|
|
|
|
"int": int,
|
|
|
|
|
@@ -248,7 +766,7 @@ class AnnotationWalker:
|
|
|
|
|
"list": list,
|
|
|
|
|
"dict": dict,
|
|
|
|
|
"set": set,
|
|
|
|
|
"frozenset": frozenset,
|
|
|
|
|
# "frozenset": frozenset,
|
|
|
|
|
"tuple": tuple,
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
@@ -285,19 +803,24 @@ class AnnotationWalker:
|
|
|
|
|
self.__ann = eval(ann, {"__builtins__": {}}, self._allowed_annotations)
|
|
|
|
|
|
|
|
|
|
def run(self) -> TriggerResult:
|
|
|
|
|
for trigger in self._triggers:
|
|
|
|
|
trigger.init_trigger()
|
|
|
|
|
return self._walk(self.__ann, None)
|
|
|
|
|
for t in self._triggers:
|
|
|
|
|
t.init_trigger()
|
|
|
|
|
self.root_ctx = None
|
|
|
|
|
self._walk(self.__ann, None)
|
|
|
|
|
for t in self._triggers:
|
|
|
|
|
t.end_trigger(self.root_ctx)
|
|
|
|
|
|
|
|
|
|
# --- Helpers ---
|
|
|
|
|
|
|
|
|
|
def _new_ctx(self, origin, args, layer, parent):
|
|
|
|
|
return AnnotationWalkerCtx(origin, args, layer, parent, self._allowed_types, self._allowed_annotations)
|
|
|
|
|
return AnnotationWalkerCtx(
|
|
|
|
|
origin, args, layer, parent, self._allowed_types, self._allowed_annotations
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _apply_triggers(self, method: str, ctx: AnnotationWalkerCtx) -> TriggerResult:
|
|
|
|
|
final = TriggerResult.passthrough()
|
|
|
|
|
for trig in self._triggers:
|
|
|
|
|
res = getattr(trig, method)(ctx)
|
|
|
|
|
for t in self._triggers:
|
|
|
|
|
res = getattr(t, method)(ctx)
|
|
|
|
|
if not res:
|
|
|
|
|
continue
|
|
|
|
|
if res.restart_with is not None:
|
|
|
|
|
@@ -311,33 +834,64 @@ class AnnotationWalker:
|
|
|
|
|
)
|
|
|
|
|
return final
|
|
|
|
|
|
|
|
|
|
def _apply_exits(self, ctx: AnnotationWalkerCtx) -> Any | None:
|
|
|
|
|
"""
|
|
|
|
|
Run exits in order: type-specific (if implemented) then generic `process_exit`.
|
|
|
|
|
Only `replace_with` is honored at exit; last `replace_with` wins.
|
|
|
|
|
"""
|
|
|
|
|
final = None
|
|
|
|
|
for t in self._triggers:
|
|
|
|
|
res = t.process_exit(ctx)
|
|
|
|
|
if res and (res.replace_with is not None):
|
|
|
|
|
final = res.replace_with
|
|
|
|
|
return final
|
|
|
|
|
|
|
|
|
|
def _handle_with_triggers(
|
|
|
|
|
self,
|
|
|
|
|
trigger_name: str,
|
|
|
|
|
ctx: AnnotationWalkerCtx,
|
|
|
|
|
args_handler: Callable[[AnnotationWalkerCtx], Any] | None = None,
|
|
|
|
|
) -> Any:
|
|
|
|
|
"""Generic handler: run triggers, maybe recurse into args with a custom handler."""
|
|
|
|
|
# ENTER
|
|
|
|
|
res = self._apply_triggers(trigger_name, ctx)
|
|
|
|
|
if res.restart_with is not None:
|
|
|
|
|
return self._walk(res.restart_with, ctx.parent)
|
|
|
|
|
if res.replace_with is not None:
|
|
|
|
|
return res.replace_with
|
|
|
|
|
exit_val = self._apply_exits(ctx)
|
|
|
|
|
return exit_val if exit_val is not None else res.replace_with
|
|
|
|
|
|
|
|
|
|
node_value = None
|
|
|
|
|
if not res.skip_children:
|
|
|
|
|
if args_handler:
|
|
|
|
|
return args_handler(ctx)
|
|
|
|
|
return tuple(self._walk(a, ctx) for a in ctx.args)
|
|
|
|
|
return None
|
|
|
|
|
node_value = args_handler(ctx)
|
|
|
|
|
else:
|
|
|
|
|
# DEFAULT: descend once per schema arg; mark as positional arg
|
|
|
|
|
node_value = tuple(
|
|
|
|
|
self._walk_child(a, ctx, arg_index=i, role="arg", token=i)
|
|
|
|
|
for i, a in enumerate(ctx.args)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# EXIT
|
|
|
|
|
exit_val = self._apply_exits(ctx)
|
|
|
|
|
return exit_val if exit_val is not None else node_value
|
|
|
|
|
|
|
|
|
|
def _walk_args_tuple(self, ctx: AnnotationWalkerCtx):
|
|
|
|
|
# special Ellipsis case for Tuple
|
|
|
|
|
# Tuple[T, ...] (variadic): one schema child, mark role='elem'
|
|
|
|
|
if len(ctx.args) == 2 and ctx.args[1] is Ellipsis:
|
|
|
|
|
return (self._walk(ctx.args[0], ctx), Ellipsis)
|
|
|
|
|
return tuple(self._walk(a, ctx) for a in ctx.args)
|
|
|
|
|
return (
|
|
|
|
|
self._walk_child(ctx.args[0], ctx, arg_index=0, role="elem", token=None),
|
|
|
|
|
Ellipsis,
|
|
|
|
|
)
|
|
|
|
|
# Fixed tuple: each positional arg gets role='arg'
|
|
|
|
|
return tuple(
|
|
|
|
|
self._walk_child(a, ctx, arg_index=i, role="arg", token=i)
|
|
|
|
|
for i, a in enumerate(ctx.args)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# --- Dispatcher ---
|
|
|
|
|
|
|
|
|
|
def _walk(self, type_: Any, parent_ctx: Optional[AnnotationWalkerCtx]) -> Any:
|
|
|
|
|
# For logs only: show the calling layer
|
|
|
|
|
print(f"[{parent_ctx.layer if parent_ctx else 0}] walking through: {type_}")
|
|
|
|
|
|
|
|
|
|
origin = get_origin(type_) or type_
|
|
|
|
|
@@ -349,31 +903,123 @@ class AnnotationWalker:
|
|
|
|
|
raise RuntimeError("Annotation must be using type(s), not instances")
|
|
|
|
|
|
|
|
|
|
args = get_args(type_)
|
|
|
|
|
layer = 0 if parent_ctx is None else parent_ctx.layer + 1
|
|
|
|
|
ctx = self._new_ctx(origin, args, layer, parent_ctx)
|
|
|
|
|
|
|
|
|
|
# IMPORTANT:
|
|
|
|
|
# If caller (_walk_child) already constructed a child ctx with edge metadata,
|
|
|
|
|
# reuse that object as *the* ctx for this node instead of allocating a new one.
|
|
|
|
|
if (
|
|
|
|
|
isinstance(parent_ctx, AnnotationWalkerCtx)
|
|
|
|
|
and parent_ctx.origin is origin
|
|
|
|
|
and parent_ctx.args == args
|
|
|
|
|
):
|
|
|
|
|
ctx = parent_ctx
|
|
|
|
|
else:
|
|
|
|
|
# Root or internal calls that didn't prebuild the ctx
|
|
|
|
|
layer = 0 if parent_ctx is None else parent_ctx.layer + 1
|
|
|
|
|
ctx = self._new_ctx(origin, args, layer, parent_ctx)
|
|
|
|
|
|
|
|
|
|
# Remember root ctx for end_trigger()
|
|
|
|
|
if ctx.parent is None:
|
|
|
|
|
self.root_ctx = ctx
|
|
|
|
|
|
|
|
|
|
print(origin)
|
|
|
|
|
match origin:
|
|
|
|
|
case typing.Annotated:
|
|
|
|
|
# inner type gets role='annotated'
|
|
|
|
|
return self._handle_with_triggers(
|
|
|
|
|
"process_annotated", ctx, args_handler=lambda c: self._walk(c.args[0], c) if c.args else None
|
|
|
|
|
"process_annotated",
|
|
|
|
|
ctx,
|
|
|
|
|
args_handler=lambda c: (
|
|
|
|
|
self._walk_child(c.args[0], c, arg_index=0, role="annotated", token=None)
|
|
|
|
|
if c.args
|
|
|
|
|
else None
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
case types.UnionType:
|
|
|
|
|
return self._handle_with_triggers("process_union", ctx)
|
|
|
|
|
# branches get role='branch' and token=branch index
|
|
|
|
|
return self._handle_with_triggers(
|
|
|
|
|
"process_union",
|
|
|
|
|
ctx,
|
|
|
|
|
args_handler=lambda c: tuple(
|
|
|
|
|
self._walk_child(a, c, arg_index=i, role="branch", token=i)
|
|
|
|
|
for i, a in enumerate(c.args)
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
case _ if issubclass(origin, dict):
|
|
|
|
|
return self._handle_with_triggers("process_dict", ctx)
|
|
|
|
|
# arg0=key (role='key'), arg1=value (role='val')
|
|
|
|
|
return self._handle_with_triggers(
|
|
|
|
|
"process_dict",
|
|
|
|
|
ctx,
|
|
|
|
|
args_handler=lambda c: (
|
|
|
|
|
self._walk_child(c.args[0], c, arg_index=0, role="key", token=None),
|
|
|
|
|
self._walk_child(c.args[1], c, arg_index=1, role="val", token=None),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
case _ if issubclass(origin, tuple):
|
|
|
|
|
return self._handle_with_triggers("process_tuple", ctx, self._walk_args_tuple)
|
|
|
|
|
|
|
|
|
|
case _ if issubclass(origin, list):
|
|
|
|
|
return self._handle_with_triggers("process_list", ctx)
|
|
|
|
|
# single child T with role='elem'
|
|
|
|
|
return self._handle_with_triggers(
|
|
|
|
|
"process_list",
|
|
|
|
|
ctx,
|
|
|
|
|
args_handler=lambda c: (
|
|
|
|
|
self._walk_child(c.args[0], c, arg_index=0, role="elem", token=None)
|
|
|
|
|
if c.args
|
|
|
|
|
else None
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
case _ if issubclass(origin, set):
|
|
|
|
|
return self._handle_with_triggers("process_set", ctx)
|
|
|
|
|
# single child T with role='elem'
|
|
|
|
|
return self._handle_with_triggers(
|
|
|
|
|
"process_set",
|
|
|
|
|
ctx,
|
|
|
|
|
args_handler=lambda c: (
|
|
|
|
|
self._walk_child(c.args[0], c, arg_index=0, role="elem", token=None)
|
|
|
|
|
if c.args
|
|
|
|
|
else None
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
case _ if origin in self._allowed_types:
|
|
|
|
|
return self._handle_with_triggers("process_allowed", ctx)
|
|
|
|
|
|
|
|
|
|
case _:
|
|
|
|
|
res = self._apply_triggers("process_unknown", ctx)
|
|
|
|
|
if res.restart_with is not None:
|
|
|
|
|
return self._walk(res.restart_with, ctx.parent)
|
|
|
|
|
if res.replace_with is not None:
|
|
|
|
|
return res.replace_with
|
|
|
|
|
raise UnsupportedFieldType(f"Not supported Field: {ctx.origin}, " f"Supported list: {self._allowed_types}")
|
|
|
|
|
raise UnsupportedFieldType(
|
|
|
|
|
f"Not supported Field: {ctx.origin}, Supported list: {self._allowed_types}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _walk_child(
|
|
|
|
|
self,
|
|
|
|
|
type_expr: Any,
|
|
|
|
|
parent_ctx: AnnotationWalkerCtx,
|
|
|
|
|
*,
|
|
|
|
|
arg_index: int,
|
|
|
|
|
role: str | None,
|
|
|
|
|
token: Any | None,
|
|
|
|
|
) -> Any:
|
|
|
|
|
origin = get_origin(type_expr) or type_expr
|
|
|
|
|
if origin is None:
|
|
|
|
|
origin = NoneType
|
|
|
|
|
if origin is Union:
|
|
|
|
|
origin = UnionType
|
|
|
|
|
if not isinstance(origin, type):
|
|
|
|
|
raise RuntimeError("Annotation must be using type(s), not instances")
|
|
|
|
|
|
|
|
|
|
args = get_args(type_expr)
|
|
|
|
|
child = self._new_ctx(origin, args, parent_ctx.layer + 1, parent_ctx)
|
|
|
|
|
# stamp routing metadata for triggers
|
|
|
|
|
child._set_edge(role=role, token=token, arg_index=arg_index)
|
|
|
|
|
|
|
|
|
|
# IMPORTANT: recurse with the CHILD as the parent_ctx for the next step,
|
|
|
|
|
# so the walker uses this child ctx (with edge metadata & incremented layer).
|
|
|
|
|
return self._walk(type_expr, child)
|
|
|
|
|
|