work
This commit is contained in:
@@ -79,6 +79,11 @@ class AnnotationWalkerCtx:
|
||||
self.__allowed_annotations: dict[str, Any] = allowed_annotations
|
||||
self.__ext: dict[Any, ChainMap] = {} # per-trigger namespaces (lazy)
|
||||
|
||||
# NEW: edge routing metadata (walker sets; triggers read)
|
||||
self.__edge_role: str | None = None # 'elem' | 'key' | 'val' | 'branch' | 'annotated' | 'arg'
|
||||
self.__edge_token: Any | None = None # index/key/branch-id/etc.
|
||||
self.__arg_index: int | None = None # which schema arg of the parent this child is
|
||||
|
||||
@property
def origin(self) -> Any:
    """Return the value held in this context's private origin slot (read-only)."""
    return self.__origin
|
||||
@@ -99,6 +104,19 @@ class AnnotationWalkerCtx:
|
||||
def allowed_annotations(self) -> Mapping[str, Any]:
|
||||
return self.__allowed_annotations
|
||||
|
||||
# NEW: read-only edge metadata for routing inside triggers
|
||||
@property
|
||||
def edge_role(self) -> str | None:
|
||||
return self.__edge_role
|
||||
|
||||
@property
|
||||
def edge_token(self) -> Any | None:
|
||||
return self.__edge_token
|
||||
|
||||
@property
|
||||
def arg_index(self) -> int | None:
|
||||
return self.__arg_index
|
||||
|
||||
def ns(self, owner: Any) -> ChainMap:
|
||||
"""
|
||||
A per-trigger overlay namespace that inherits from parent ctx.
|
||||
@@ -107,11 +125,30 @@ class AnnotationWalkerCtx:
|
||||
"""
|
||||
if owner in self.__ext:
|
||||
return self.__ext[owner]
|
||||
parent_map = self.__parent.__ext.get(owner) if (self.__parent and hasattr(self.__parent, "_AnnotationWalkerCtx__ext")) else {}
|
||||
cm = ChainMap({}, parent_map if isinstance(parent_map, ChainMap) else dict(parent_map))
|
||||
|
||||
# Determine the parent chain for this owner
|
||||
parent_chain = ()
|
||||
if self.__parent is not None and hasattr(self.__parent, "_AnnotationWalkerCtx__ext"):
|
||||
parent_map = self.__parent._AnnotationWalkerCtx__ext.get(owner) # access private attr intentionally
|
||||
if isinstance(parent_map, ChainMap):
|
||||
# IMPORTANT: preserve the whole chain (not the ChainMap object as a single mapping)
|
||||
parent_chain = tuple(parent_map.maps)
|
||||
elif isinstance(parent_map, dict):
|
||||
parent_chain = (parent_map,)
|
||||
else:
|
||||
parent_chain = ()
|
||||
|
||||
# Build a new ChainMap whose first map is this node's local overlay
|
||||
cm = ChainMap({}, *parent_chain)
|
||||
self.__ext[owner] = cm
|
||||
return cm
|
||||
|
||||
# INTERNAL (walker only): set edge metadata on a child ctx
|
||||
def _set_edge(self, *, role: str | None, token: Any | None, arg_index: int | None) -> None:
|
||||
self.__edge_role = role
|
||||
self.__edge_token = token
|
||||
self.__arg_index = arg_index
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TriggerResult:
|
||||
@@ -169,6 +206,12 @@ class AnnotationTrigger:
|
||||
def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
    """Default hook for the walker's ``process_allowed`` dispatch: no action, no result."""
    return None
|
||||
|
||||
def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
    """Default exit hook: no action, no result."""
    return None
|
||||
|
||||
def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]) -> None:
    """Default end-of-walk hook: does nothing. ``ctx`` may be ``None``."""
|
||||
|
||||
|
||||
class LAMSchemaValidation(AnnotationTrigger):
|
||||
def init_trigger(self) -> None:
|
||||
@@ -224,6 +267,462 @@ class LAMSchemaValidation(AnnotationTrigger):
|
||||
raise IncompletelyAnnotatedField(f"None is only accepted with Union, to implement Optional[]")
|
||||
return None
|
||||
|
||||
def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
    """Trace the exited node's origin on stdout; never replaces or restarts anything."""
    print(f"process_exit: {ctx.origin}")
    return None
|
||||
|
||||
def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]) -> None:
    """
    End-of-walk hook; this trigger has nothing to finalize.

    The parameter is annotated as optional to match the base hook: the walker's
    ``run()`` passes ``self.root_ctx``, which it initializes to ``None``.
    """
|
||||
|
||||
|
||||
class DataValidation(AnnotationTrigger):
    """
    Skeleton trigger that seeds a per-trigger namespace with the value under validation.

    Every ``process_*`` hook currently only makes sure the namespace has been
    seeded (see ``_bag``); no validation is performed yet.
    """

    def __init__(self, value: Any) -> None:
        self._root = value
        # Also defined here (not only in init_trigger) so a hook invoked before
        # init_trigger() cannot fail with AttributeError inside _bag().
        self._seeded = False

    def init_trigger(self) -> None:
        """Reset the seeding flag so the next ``_bag`` call seeds again."""
        self._seeded = False

    def _bag(self, ctx: AnnotationWalkerCtx):
        """
        Return this trigger's namespace for ``ctx``.

        On the first call after (re)initialization the namespace receives
        ``"value"`` (the root value) and ``"path"`` (an empty tuple).
        """
        bag = ctx.ns(self)
        if not self._seeded:
            bag["value"] = self._root
            bag["path"] = ()
            self._seeded = True
        return bag

    def process_annotated(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_union(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_dict(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_tuple(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_list(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_set(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_unknown(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
        self._bag(ctx)
        return None

    def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]) -> None:
        """Nothing to finalize."""
|
||||
|
||||
|
||||
class SchemaValidationError(TypeError):
    """
    Validation failure located by a path into the validated value.

    The message is ``"<rendered path>: <msg>"``. Integer path parts render as
    ``[n]``; any other part renders as ``.part`` (no leading dot when it is the
    first part). An empty rendering becomes ``<root>``. The raw path is kept
    on ``self.path``.
    """

    def __init__(self, path: tuple[Any, ...], msg: str):
        pieces = []
        for position, part in enumerate(path):
            if isinstance(part, int):
                pieces.append(f"[{part}]")
            elif position:
                pieces.append(f".{part}")
            else:
                pieces.append(str(part))
        location = "".join(pieces) or "<root>"
        super().__init__(f"{location}: {msg}")
        self.path = path
|
||||
|
||||
|
||||
class HorizontalValidationTrigger(AnnotationTrigger):
|
||||
def __init__(
    self,
    value: Any,
    *,
    strict_bool: bool = False,
    collect_all: bool = True,
    union_summary_max_per_branch: int = 3,
    show_branch_types: bool = True,
):
    """
    Store the value to validate and the reporting options.

    Args:
        value: root value; becomes the first candidate when the root node is seeded.
        strict_bool: consulted by ``_check_leaf`` when checking bool leaves.
        collect_all: when false, ``_err`` raises the first error recorded
            outside a union branch instead of only collecting it.
        union_summary_max_per_branch: maximum number of errors listed per
            branch in a union failure summary.
        show_branch_types: stored as ``_show_branch_types``; not read by any
            code visible in this section.
    """
    # bookkeeping
    self._seeded = False
    self._id_counter = 0
    # input
    self._root_value = value
    # options
    self._strict_bool = strict_bool
    self._collect_all = collect_all
    self._union_summary_max = union_summary_max_per_branch
    self._show_branch_types = show_branch_types
|
||||
|
||||
# ---------- utilities ----------
|
||||
|
||||
# ---------- utilities ----------
|
||||
|
||||
def _mk_cand(self, value: Any, path: tuple[Any, ...]):
    """Build a candidate record ``{"id", "value", "path"}`` and advance the id counter."""
    fresh_id, self._id_counter = self._id_counter, self._id_counter + 1
    return dict(id=fresh_id, value=value, path=path)
|
||||
|
||||
def _spawn_from(self, bag, parent_cand: dict, value: Any, path: tuple[Any, ...]):
    """
    Create a child candidate for ``value`` at ``path``.

    The child gets a ``"union_cid"`` when ``bag`` marks a union branch (the
    parent's ``union_cid``, else the parent's id) or when the parent already
    carries one (copied as is).
    """
    child = self._mk_cand(value, path)
    inside_branch = "__union_branch_id" in bag
    if inside_branch:
        own_id = parent_cand["id"]
        child["union_cid"] = parent_cand.get("union_cid", own_id)
    elif "union_cid" in parent_cand:
        child["union_cid"] = parent_cand["union_cid"]
    return child
|
||||
|
||||
def _bag(self, ctx: AnnotationWalkerCtx):
|
||||
bag = ctx.ns(self)
|
||||
local = bag.maps[0]
|
||||
parent_bag = ctx.parent.ns(self) if ctx.parent is not None else None
|
||||
|
||||
if not self._seeded and ctx.parent is None:
|
||||
local["candidates"] = [self._mk_cand(self._root_value, ())]
|
||||
local["errors"] = []
|
||||
self._seeded = True
|
||||
|
||||
if "errors" not in local and parent_bag is not None:
|
||||
local["errors"] = parent_bag.get("errors", [])
|
||||
|
||||
if ctx.parent is not None:
|
||||
role = ctx.edge_role
|
||||
if role == "elem":
|
||||
local["candidates"] = list(parent_bag.get("pending_elem", []))
|
||||
elif role == "key":
|
||||
local["candidates"] = list(parent_bag.get("pending_key", []))
|
||||
elif role == "val":
|
||||
local["candidates"] = list(parent_bag.get("pending_val", []))
|
||||
elif role == "arg":
|
||||
local["candidates"] = list(parent_bag.get(f"pending_arg_{ctx.arg_index}", []))
|
||||
elif role in ("branch", "annotated"):
|
||||
local["candidates"] = list(parent_bag.get("candidates", []))
|
||||
else:
|
||||
local["candidates"] = list(parent_bag.get("candidates", []))
|
||||
|
||||
# If immediate child of a Union (branch), switch to branch-local errors
|
||||
if ctx.parent is not None and ctx.parent.origin is UnionType and ctx.edge_role == "branch":
|
||||
ubag = ctx.parent.ns(self)
|
||||
branches = ubag.setdefault("union_branches_state", {})
|
||||
bid = ctx.edge_token
|
||||
bstate = branches.setdefault(bid, {"failed_ids": set(), "errors": []})
|
||||
local["errors"] = bstate["errors"]
|
||||
local["__union_branch_id"] = bid
|
||||
local["__union_state_ref"] = ubag
|
||||
# NEW: propagate the union scope id to this branch
|
||||
local["__union_scope_id"] = ubag.get("__union_scope_id")
|
||||
|
||||
# Stamp union_cid for all candidates at branch entry
|
||||
stamped = []
|
||||
for c in local.get("candidates", []):
|
||||
c2 = dict(c)
|
||||
c2.setdefault("union_cid", c2["id"])
|
||||
stamped.append(c2)
|
||||
local["candidates"] = stamped
|
||||
|
||||
return bag
|
||||
|
||||
def _mark_branch_fail(self, bag, cand: dict):
    """
    Add ``cand`` to the failed set of the union branch that ``bag`` belongs to.

    Does nothing unless ``bag`` exposes both ``__union_branch_id`` and
    ``__union_state_ref``. The id recorded is the candidate's ``union_cid``
    when present, otherwise its own ``id``.
    """
    if "__union_branch_id" not in bag or "__union_state_ref" not in bag:
        return
    branch_id = bag["__union_branch_id"]
    union_bag = bag["__union_state_ref"]
    states = union_bag.setdefault("union_branches_state", {})
    state = states.setdefault(branch_id, {"failed_ids": set(), "errors": []})
    state["failed_ids"].add(cand.get("union_cid", cand["id"]))
|
||||
|
||||
def _err(self, bag, cand, msg: str):
    """
    Record a validation error for ``cand``.

    Inside a union branch (``bag`` exposes both ``__union_branch_id`` and
    ``__union_state_ref``) the error goes into that branch's bucket as
    ``{"cid", "scope", "err"}``, the candidate's union-root id is added to the
    branch's failed set, and nothing is raised.

    Otherwise the error is appended to ``bag["errors"]`` and, when
    ``collect_all`` is off, raised immediately.
    """
    error = SchemaValidationError(cand["path"], msg)

    if "__union_branch_id" not in bag or "__union_state_ref" not in bag:
        bag["errors"].append(error)
        if not self._collect_all:
            raise error
        return

    owner_id = cand.get("union_cid", cand["id"])
    # Mapping.get already yields None when no enclosing scope id is visible
    scope = bag.get("__union_scope_id")
    union_bag = bag["__union_state_ref"]
    branch_id = bag["__union_branch_id"]
    states = union_bag.setdefault("union_branches_state", {})
    state = states.setdefault(branch_id, {"failed_ids": set(), "errors": []})
    state["failed_ids"].add(owner_id)
    state["errors"].append({"cid": owner_id, "scope": scope, "err": error})
|
||||
|
||||
def _check_leaf(self, bag, cand: dict, T: type):
    """
    Check ``cand["value"]`` against leaf type ``T``, reporting via ``self._err``.

    * ``T`` is ``NoneType``: only ``None`` is accepted.
    * ``strict_bool`` on: a ``bool`` leaf accepts only real bools, and a bool
      value is rejected for every other leaf type. Previously the flag had no
      practical effect, so ``True`` still satisfied ``int`` because ``bool``
      subclasses ``int``.
    * otherwise: ``None`` is rejected and the value must pass ``isinstance``.

    With ``strict_bool`` off the behaviour is unchanged.
    """
    value = cand["value"]
    if T is type(None):
        if value is not None:
            self._err(bag, cand, "expected None")
        return
    if self._strict_bool:
        if T is bool and type(value) is not bool:
            self._err(bag, cand, f"expected bool, got {type(value).__name__}")
            return
        if T is not bool and isinstance(value, bool):
            self._err(bag, cand, f"expected {T.__name__}, got bool")
            return
    if value is None or not isinstance(value, T):
        got = "None" if value is None else type(value).__name__
        self._err(bag, cand, f"expected {T.__name__}, got {got}")
|
||||
|
||||
# ---------- entry hooks (seed pending; validate leaves; no aggregation here) ----------
|
||||
|
||||
def process_allowed(self, ctx: AnnotationWalkerCtx):
    """Leaf node: check every routed candidate against ``ctx.origin``."""
    bag = self._bag(ctx)
    expected = ctx.origin
    for candidate in bag.get("candidates", []):
        self._check_leaf(bag, candidate, expected)
|
||||
|
||||
def process_list(self, ctx: AnnotationWalkerCtx):
    """
    List node: verify each candidate is a ``list`` and queue its items.

    Items of all valid candidates are published as ``pending_elem`` (path
    extended by the item index). If the annotation does not have exactly one
    argument, every candidate gets an error and nothing is published.
    """
    bag = self._bag(ctx)
    incoming = bag.get("candidates", [])

    if len(ctx.args) != 1:
        for cand in incoming:
            self._err(bag, cand, "List[T] requires 1 argument")
        return

    queued = []
    for cand in incoming:
        container = cand["value"]
        if isinstance(container, list):
            prefix = cand["path"]
            for index, item in enumerate(container):
                queued.append(self._spawn_from(bag, cand, item, prefix + (index,)))
        else:
            self._err(bag, cand, f"expected list, got {type(container).__name__}")
    bag["pending_elem"] = queued
|
||||
|
||||
def process_tuple(self, ctx: AnnotationWalkerCtx):
    """
    Tuple node: verify each candidate is a ``tuple`` and queue its items.

    * ``tuple[T, ...]``: items of all valid candidates are published as
      ``pending_elem``.
    * fixed arity: a candidate must have exactly ``len(ctx.args)`` items; item
      ``i`` is published under ``pending_arg_<i>``.

    Every ``pending_arg_<i>`` is published, as an empty list when no candidate
    qualified. Children look these keys up in a ChainMap, so a missing key
    would silently resolve to the ``pending_arg_<i>`` of an enclosing tuple.
    """
    bag = self._bag(ctx)
    incoming = bag.get("candidates", [])
    arity = len(ctx.args)
    variadic = arity == 2 and ctx.args[1] is Ellipsis

    elements = []
    slots = [[] for _ in range(arity)]
    for cand in incoming:
        value = cand["value"]
        if not isinstance(value, tuple):
            self._err(bag, cand, f"expected tuple, got {type(value).__name__}")
            continue
        if not variadic and len(value) != arity:
            self._err(bag, cand, f"expected tuple len {arity}, got {len(value)}")
            continue
        prefix = cand["path"]
        for index, item in enumerate(value):
            child = self._spawn_from(bag, cand, item, prefix + (index,))
            if variadic:
                elements.append(child)
            else:
                slots[index].append(child)

    if variadic:
        bag["pending_elem"] = elements
    else:
        for index, group in enumerate(slots):
            bag[f"pending_arg_{index}"] = group
|
||||
|
||||
def process_set(self, ctx: AnnotationWalkerCtx):
    """
    Set node: verify each candidate is a ``set`` or ``frozenset`` and queue its members.

    Members of all valid candidates are published as ``pending_elem``; each
    member's path is the candidate's path plus the literal ``"<elem>"``. If the
    annotation does not have exactly one argument, every candidate gets an
    error and nothing is published.
    """
    bag = self._bag(ctx)
    incoming = bag.get("candidates", [])

    if len(ctx.args) != 1:
        for cand in incoming:
            self._err(bag, cand, "Set[T]/FrozenSet[T] requires 1 argument")
        return

    queued = []
    for cand in incoming:
        members = cand["value"]
        if isinstance(members, (set, frozenset)):
            member_path = cand["path"] + ("<elem>",)
            for member in members:
                queued.append(self._spawn_from(bag, cand, member, member_path))
        else:
            self._err(bag, cand, f"expected set/frozenset, got {type(members).__name__}")
    bag["pending_elem"] = queued
|
||||
|
||||
def process_dict(self, ctx: AnnotationWalkerCtx):
    """
    Dict node: verify each candidate is a ``dict`` and queue its keys and values.

    Keys are published as ``pending_key`` (path plus the literal ``"<key>"``)
    and values as ``pending_val`` (path plus the key itself). If the
    annotation does not have exactly two arguments, every candidate gets an
    error and nothing is published.
    """
    bag = self._bag(ctx)
    incoming = bag.get("candidates", [])

    if len(ctx.args) != 2:
        for cand in incoming:
            self._err(bag, cand, "Dict[K,V] requires 2 arguments")
        return

    queued_keys = []
    queued_values = []
    for cand in incoming:
        mapping = cand["value"]
        if not isinstance(mapping, dict):
            self._err(bag, cand, f"expected dict, got {type(mapping).__name__}")
            continue
        prefix = cand["path"]
        for key, item in mapping.items():
            # key first, then value: keeps candidate ids in the original order
            queued_keys.append(self._spawn_from(bag, cand, key, prefix + ("<key>",)))
            queued_values.append(self._spawn_from(bag, cand, item, prefix + (key,)))
    bag["pending_key"] = queued_keys
    bag["pending_val"] = queued_values
|
||||
|
||||
def process_annotated(self, ctx: AnnotationWalkerCtx):
    """Annotated node: only route candidates into this node; the inner type validates them."""
    self._bag(ctx)
|
||||
|
||||
def process_union(self, ctx: AnnotationWalkerCtx):
    """
    Union node: record what ``process_exit`` needs to judge this union.

    Stored on this node's local map:

    * ``union_candidate_ids``: ids of the candidates that arrived here
    * ``union_branches_state``: branch id -> ``{"failed_ids", "errors"}``
    * ``union_branch_labels``: pretty-printed type of each branch
    * ``__union_scope_id``: ``id(ctx)``, used to tag errors of this union

    The branch state is created on the node-local map on purpose. A
    ``setdefault`` on the ChainMap itself would find the state of an enclosing
    union and nested unions would end up sharing one set of branch buckets.
    """
    bag = self._bag(ctx)
    local = bag.maps[0]
    local["union_candidate_ids"] = [cand["id"] for cand in bag.get("candidates", [])]
    local.setdefault("union_branches_state", {})
    local["union_branch_labels"] = [self._pretty_type(arg) for arg in ctx.args]
    local["__union_scope_id"] = id(ctx)
|
||||
|
||||
# ---------- exit (aggregation only) ----------
|
||||
|
||||
def process_exit(self, ctx: AnnotationWalkerCtx):
|
||||
if ctx.origin is not UnionType:
|
||||
return None
|
||||
|
||||
bag = ctx.ns(self)
|
||||
root_errors = bag.get("errors", [])
|
||||
branches = bag.get("union_branches_state", {})
|
||||
cand_ids = bag.get("union_candidate_ids", [])
|
||||
labels = bag.get("union_branch_labels", [])
|
||||
scope_id = bag.get("__union_scope_id") # THIS union's scope
|
||||
|
||||
# Build: per_branch_errors[bid][cid] -> [SchemaValidationError, ...] filtered to THIS union scope
|
||||
per_branch_errors: dict[int, dict[int, list[SchemaValidationError]]] = {}
|
||||
for bid, state in branches.items():
|
||||
berrs = state.get("errors", [])
|
||||
m: dict[int, list[SchemaValidationError]] = {}
|
||||
for item in berrs:
|
||||
if isinstance(item, dict):
|
||||
# Only errors tagged for THIS union scope
|
||||
if item.get("scope") != scope_id:
|
||||
continue
|
||||
cid = item.get("cid")
|
||||
err = item.get("err")
|
||||
if cid is not None and isinstance(err, SchemaValidationError):
|
||||
m.setdefault(cid, []).append(err)
|
||||
elif isinstance(item, SchemaValidationError):
|
||||
# Legacy/plain error: attribute to all failed ids of this branch
|
||||
for cid in state.get("failed_ids", set()):
|
||||
m.setdefault(cid, []).append(item)
|
||||
# Dedupe identical (path, text)
|
||||
for cid, lst in m.items():
|
||||
seen = set()
|
||||
uniq = []
|
||||
for e in lst:
|
||||
key = (e.path, str(e))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
uniq.append(e)
|
||||
m[cid] = uniq
|
||||
per_branch_errors[bid] = m
|
||||
|
||||
# Decide per candidate of THIS union
|
||||
for cid in cand_ids:
|
||||
branch_ok = any(len(per_branch_errors.get(bid, {}).get(cid, [])) == 0 for bid in branches.keys())
|
||||
if branch_ok:
|
||||
continue # some branch matched, good
|
||||
|
||||
# Pretty summary for this failing candidate
|
||||
# Get the path for this union's candidate; fall back to <root> if absent
|
||||
cands_here = bag.get("candidates", [])
|
||||
path = next((c["path"] for c in cands_here if c.get("id") == cid), ())
|
||||
|
||||
lines = ["no union branch matched; tried:"]
|
||||
for bid, state in branches.items():
|
||||
label = labels[bid] if bid < len(labels) else f"branch {bid}"
|
||||
errs = per_branch_errors.get(bid, {}).get(cid, [])
|
||||
if not errs:
|
||||
lines.append(f" - {label}: mismatch")
|
||||
continue
|
||||
N = self._union_summary_max
|
||||
shown = errs[:N]
|
||||
lines.append(f" - {label}: {len(errs)} issue(s)")
|
||||
for e in shown:
|
||||
lines.append(f" - {e}")
|
||||
if len(errs) > N:
|
||||
lines.append(f" - (+{len(errs) - N} more)")
|
||||
|
||||
root_errors.append(SchemaValidationError(path, "\n".join(lines)))
|
||||
|
||||
# Also, if THIS union is nested inside an enclosing union branch, and ALL branches
|
||||
# of THIS union failed for a candidate, bubble a single 'mismatch' up to that parent
|
||||
# so the parent union can mark its branch as failed for that union-root cid.
|
||||
if ("__union_state_ref" in bag) and ("__union_branch_id" in bag):
|
||||
uref = bag["__union_state_ref"] # enclosing union's bag
|
||||
up_bid = bag["__union_branch_id"] # which branch we're in
|
||||
up_branches = uref.setdefault("union_branches_state", {})
|
||||
up_state = up_branches.setdefault(up_bid, {"failed_ids": set(), "errors": []})
|
||||
up_scope = uref.get("__union_scope_id")
|
||||
|
||||
# Map THIS union's local cid to the enclosing union's root cid via union_cid
|
||||
local_cands = bag.get("candidates", [])
|
||||
# Fallback map: if local_cands missing, assume 1:1
|
||||
fallback_map = {cid: cid for cid in cand_ids}
|
||||
map_local_to_outer = {}
|
||||
for lc in local_cands:
|
||||
local_id = lc.get("id")
|
||||
outer_root = lc.get("union_cid", local_id)
|
||||
map_local_to_outer[local_id] = outer_root
|
||||
|
||||
for cid in cand_ids:
|
||||
branch_ok = any(len(per_branch_errors.get(bid, {}).get(cid, [])) == 0 for bid in branches.keys())
|
||||
if branch_ok:
|
||||
continue # do not bubble success
|
||||
|
||||
outer_cid = map_local_to_outer.get(cid, fallback_map[cid])
|
||||
up_state["failed_ids"].add(outer_cid)
|
||||
# lightweight marker so parent has at least one error on record for this cid
|
||||
up_state["errors"].append({"cid": outer_cid, "scope": up_scope, "err": SchemaValidationError((), "mismatch")})
|
||||
|
||||
return None
|
||||
|
||||
def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]):
|
||||
if ctx is None:
|
||||
return
|
||||
errors = ctx.ns(self).get("errors", [])
|
||||
if errors:
|
||||
# Raise them together; swap for your preferred error carrier if needed
|
||||
raise ExceptionGroup("schema validation failed", errors)
|
||||
|
||||
# --- pretty type for messages (best-effort) ---
|
||||
def _pretty_type(self, t: Any) -> str:
|
||||
origin = get_origin(t) or t
|
||||
args = get_args(t)
|
||||
try:
|
||||
if origin is UnionType or origin is Union:
|
||||
return " | ".join(self._pretty_type(a) for a in args)
|
||||
if origin in (list, tuple, set, frozenset, dict, Annotated):
|
||||
if origin is dict and len(args) == 2:
|
||||
return f"dict[{self._pretty_type(args[0])}, {self._pretty_type(args[1])}]"
|
||||
if origin in (list, set, frozenset) and len(args) == 1:
|
||||
name = "list" if origin is list else ("set" if origin is set else "frozenset")
|
||||
return f"{name}[{self._pretty_type(args[0])}]"
|
||||
if origin is tuple:
|
||||
if len(args) == 2 and args[1] is Ellipsis:
|
||||
return f"tuple[{self._pretty_type(args[0])}, ...]"
|
||||
return f"tuple[{', '.join(self._pretty_type(a) for a in args)}]"
|
||||
if origin is Annotated and args:
|
||||
return f"Annotated[{self._pretty_type(args[0])}, ...]"
|
||||
if isinstance(origin, type):
|
||||
return origin.__name__
|
||||
return str(t)
|
||||
except Exception:
|
||||
return repr(t)
|
||||
|
||||
|
||||
class AnnotationWalker:
|
||||
DEFAULT_ALLOWED_TYPES = frozenset({str, int, float, complex, bool, bytes, NoneType})
|
||||
@@ -235,7 +734,7 @@ class AnnotationWalker:
|
||||
"Dict": Dict,
|
||||
"Tuple": Tuple,
|
||||
"Set": Set,
|
||||
"FrozenSet": FrozenSet,
|
||||
# "FrozenSet": FrozenSet,
|
||||
"Annotated": Annotated,
|
||||
# builtins:
|
||||
"int": int,
|
||||
@@ -248,7 +747,7 @@ class AnnotationWalker:
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"set": set,
|
||||
"frozenset": frozenset,
|
||||
# "frozenset": frozenset,
|
||||
"tuple": tuple,
|
||||
}
|
||||
)
|
||||
@@ -285,9 +784,12 @@ class AnnotationWalker:
|
||||
self.__ann = eval(ann, {"__builtins__": {}}, self._allowed_annotations)
|
||||
|
||||
def run(self) -> Any:
    """
    Walk the stored annotation once with all triggers.

    Order: ``init_trigger()`` on every trigger, then the walk, then
    ``end_trigger(root_ctx)`` on every trigger. ``root_ctx`` is reset to
    ``None`` before the walk; the walk is expected to assign it.

    The stale pre-change lines (a second init loop and an early
    ``return self._walk(...)``) are removed: the early return skipped the
    ``end_trigger`` pass altogether. The walk result is still returned, so
    callers of the earlier form keep getting a value.

    Returns:
        Whatever ``_walk`` returned for the root annotation.
    """
    for trigger in self._triggers:
        trigger.init_trigger()

    self.root_ctx = None
    result = self._walk(self.__ann, None)

    for trigger in self._triggers:
        trigger.end_trigger(self.root_ctx)
    return result
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
@@ -296,8 +798,8 @@ class AnnotationWalker:
|
||||
|
||||
def _apply_triggers(self, method: str, ctx: AnnotationWalkerCtx) -> TriggerResult:
|
||||
final = TriggerResult.passthrough()
|
||||
for trig in self._triggers:
|
||||
res = getattr(trig, method)(ctx)
|
||||
for t in self._triggers:
|
||||
res = getattr(t, method)(ctx)
|
||||
if not res:
|
||||
continue
|
||||
if res.restart_with is not None:
|
||||
@@ -311,33 +813,55 @@ class AnnotationWalker:
|
||||
)
|
||||
return final
|
||||
|
||||
def _apply_exits(self, ctx: AnnotationWalkerCtx) -> Any | None:
    """
    Call ``process_exit`` on every trigger, in registration order.

    Only ``replace_with`` is honored at exit. Returns the last non-``None``
    ``replace_with`` among the results, or ``None`` if there is none.
    """
    outcomes = [trigger.process_exit(ctx) for trigger in self._triggers]
    replacements = [
        outcome.replace_with
        for outcome in outcomes
        if outcome and outcome.replace_with is not None
    ]
    return replacements[-1] if replacements else None
|
||||
|
||||
def _handle_with_triggers(
    self,
    trigger_name: str,
    ctx: AnnotationWalkerCtx,
    args_handler: Callable[[AnnotationWalkerCtx], Any] | None = None,
) -> Any:
    """
    Run the enter hook ``trigger_name`` for ``ctx``, descend, then run the exit hooks.

    * ``restart_with`` set: walk the replacement under ``ctx.parent`` and
      return that result; no exit hooks run for ``ctx``.
    * ``replace_with`` set: children are not visited; exit hooks run and may
      override the replacement.
    * otherwise: unless ``skip_children`` is set, children are visited through
      ``args_handler`` when given, else once per schema argument with role
      ``"arg"``. Exit hooks run and may override the node value.

    The stale pre-change ``return`` lines are removed: they returned before the
    exit hooks and bypassed ``_walk_child``.
    """
    # ENTER
    res = self._apply_triggers(trigger_name, ctx)
    if res.restart_with is not None:
        return self._walk(res.restart_with, ctx.parent)
    if res.replace_with is not None:
        exit_val = self._apply_exits(ctx)
        return exit_val if exit_val is not None else res.replace_with

    node_value = None
    if not res.skip_children:
        if args_handler:
            node_value = args_handler(ctx)
        else:
            # DEFAULT: descend once per schema arg; mark as positional arg
            node_value = tuple(
                self._walk_child(arg, ctx, arg_index=index, role="arg", token=index)
                for index, arg in enumerate(ctx.args)
            )

    # EXIT
    exit_val = self._apply_exits(ctx)
    return exit_val if exit_val is not None else node_value
|
||||
|
||||
def _walk_args_tuple(self, ctx: AnnotationWalkerCtx):
    """
    Descend into the arguments of a tuple annotation.

    ``tuple[T, ...]`` has a single schema child, walked with role ``"elem"``;
    the result is ``(walked_T, Ellipsis)``. A fixed tuple walks each argument
    with role ``"arg"`` and its position as both token and ``arg_index``.

    The stale pre-change returns that went through plain ``_walk`` are
    removed: they gave the children no edge metadata.
    """
    if len(ctx.args) == 2 and ctx.args[1] is Ellipsis:
        element = self._walk_child(ctx.args[0], ctx, arg_index=0, role="elem", token=None)
        return (element, Ellipsis)
    return tuple(
        self._walk_child(arg, ctx, arg_index=index, role="arg", token=index)
        for index, arg in enumerate(ctx.args)
    )
|
||||
|
||||
# --- Dispatcher ---
|
||||
|
||||
def _walk(self, type_: Any, parent_ctx: Optional[AnnotationWalkerCtx]) -> Any:
|
||||
# For logs only: show the calling layer
|
||||
print(f"[{parent_ctx.layer if parent_ctx else 0}] walking through: {type_}")
|
||||
|
||||
origin = get_origin(type_) or type_
|
||||
@@ -349,31 +873,102 @@ class AnnotationWalker:
|
||||
raise RuntimeError("Annotation must be using type(s), not instances")
|
||||
|
||||
args = get_args(type_)
|
||||
|
||||
# IMPORTANT:
|
||||
# If caller (_walk_child) already constructed a child ctx with edge metadata,
|
||||
# reuse that object as *the* ctx for this node instead of allocating a new one.
|
||||
if isinstance(parent_ctx, AnnotationWalkerCtx) and parent_ctx.origin is origin and parent_ctx.args == args:
|
||||
ctx = parent_ctx
|
||||
else:
|
||||
# Root or internal calls that didn't prebuild the ctx
|
||||
layer = 0 if parent_ctx is None else parent_ctx.layer + 1
|
||||
ctx = self._new_ctx(origin, args, layer, parent_ctx)
|
||||
|
||||
# Remember root ctx for end_trigger()
|
||||
if ctx.parent is None:
|
||||
self.root_ctx = ctx
|
||||
|
||||
print(origin)
|
||||
match origin:
|
||||
case typing.Annotated:
|
||||
# inner type gets role='annotated'
|
||||
return self._handle_with_triggers(
|
||||
"process_annotated", ctx, args_handler=lambda c: self._walk(c.args[0], c) if c.args else None
|
||||
"process_annotated",
|
||||
ctx,
|
||||
args_handler=lambda c: self._walk_child(c.args[0], c, arg_index=0, role="annotated", token=None) if c.args else None,
|
||||
)
|
||||
|
||||
case types.UnionType:
|
||||
return self._handle_with_triggers("process_union", ctx)
|
||||
# branches get role='branch' and token=branch index
|
||||
return self._handle_with_triggers(
|
||||
"process_union",
|
||||
ctx,
|
||||
args_handler=lambda c: tuple(self._walk_child(a, c, arg_index=i, role="branch", token=i) for i, a in enumerate(c.args)),
|
||||
)
|
||||
|
||||
case _ if issubclass(origin, dict):
|
||||
return self._handle_with_triggers("process_dict", ctx)
|
||||
# arg0=key (role='key'), arg1=value (role='val')
|
||||
return self._handle_with_triggers(
|
||||
"process_dict",
|
||||
ctx,
|
||||
args_handler=lambda c: (
|
||||
self._walk_child(c.args[0], c, arg_index=0, role="key", token=None),
|
||||
self._walk_child(c.args[1], c, arg_index=1, role="val", token=None),
|
||||
),
|
||||
)
|
||||
|
||||
case _ if issubclass(origin, tuple):
|
||||
return self._handle_with_triggers("process_tuple", ctx, self._walk_args_tuple)
|
||||
|
||||
case _ if issubclass(origin, list):
|
||||
return self._handle_with_triggers("process_list", ctx)
|
||||
# single child T with role='elem'
|
||||
return self._handle_with_triggers(
|
||||
"process_list",
|
||||
ctx,
|
||||
args_handler=lambda c: self._walk_child(c.args[0], c, arg_index=0, role="elem", token=None) if c.args else None,
|
||||
)
|
||||
|
||||
case _ if issubclass(origin, set):
|
||||
return self._handle_with_triggers("process_set", ctx)
|
||||
# single child T with role='elem'
|
||||
return self._handle_with_triggers(
|
||||
"process_set",
|
||||
ctx,
|
||||
args_handler=lambda c: self._walk_child(c.args[0], c, arg_index=0, role="elem", token=None) if c.args else None,
|
||||
)
|
||||
|
||||
case _ if origin in self._allowed_types:
|
||||
return self._handle_with_triggers("process_allowed", ctx)
|
||||
|
||||
case _:
|
||||
res = self._apply_triggers("process_unknown", ctx)
|
||||
if res.restart_with is not None:
|
||||
return self._walk(res.restart_with, ctx.parent)
|
||||
if res.replace_with is not None:
|
||||
return res.replace_with
|
||||
raise UnsupportedFieldType(f"Not supported Field: {ctx.origin}, " f"Supported list: {self._allowed_types}")
|
||||
raise UnsupportedFieldType(f"Not supported Field: {ctx.origin}, Supported list: {self._allowed_types}")
|
||||
|
||||
def _walk_child(
|
||||
self,
|
||||
type_expr: Any,
|
||||
parent_ctx: AnnotationWalkerCtx,
|
||||
*,
|
||||
arg_index: int,
|
||||
role: str | None,
|
||||
token: Any | None,
|
||||
) -> Any:
|
||||
origin = get_origin(type_expr) or type_expr
|
||||
if origin is None:
|
||||
origin = NoneType
|
||||
if origin is Union:
|
||||
origin = UnionType
|
||||
if not isinstance(origin, type):
|
||||
raise RuntimeError("Annotation must be using type(s), not instances")
|
||||
|
||||
args = get_args(type_expr)
|
||||
child = self._new_ctx(origin, args, parent_ctx.layer + 1, parent_ctx)
|
||||
# stamp routing metadata for triggers
|
||||
child._set_edge(role=role, token=token, arg_index=arg_index)
|
||||
|
||||
# IMPORTANT: recurse with the CHILD as the parent_ctx for the next step,
|
||||
# so the walker uses this child ctx (with edge metadata & incremented layer).
|
||||
return self._walk(type_expr, child)
|
||||
|
||||
@@ -22,11 +22,18 @@ testdir_path = Path(__file__).parent.resolve()
|
||||
chdir(testdir_path.parent.resolve())
|
||||
|
||||
|
||||
class ElementTest(unittest.TestCase):
|
||||
class AnnotationsWalkerTest(unittest.TestCase):
|
||||
def setUp(self):
    """Print the id of the test about to run, on its own line."""
    print("\n->", unittest.TestCase.id(self))
|
||||
|
||||
def test_element_simple(self):
|
||||
def test_validate(self):
    """Run HorizontalValidationTrigger over a union-of-dicts annotation and a sample value."""
    ann = dict[int, list[int]] | dict[int, list[int | str]]
    val = {1: [2], 2: ["a", [1]]}

    trigger = dm.tools.HorizontalValidationTrigger(val)
    walker = dm.tools.AnnotationWalker(ann, (trigger,))
    walker.run()
|
||||
|
||||
def test_simple(self):
|
||||
print(isinstance(None, type(None)))
|
||||
print("\n== From OBJs ==")
|
||||
res = dm.tools.AnnotationWalker(Annotated[Optional[dict[int, list[str]]], "comment"], (dm.tools.LAMSchemaValidation(),))
|
||||
|
||||
Reference in New Issue
Block a user