Compare commits

...

1 Commits

Author SHA1 Message Date
cclecle
25d5339946 work 2025-09-28 21:18:35 +02:00
2 changed files with 629 additions and 27 deletions

View File

@@ -79,6 +79,11 @@ class AnnotationWalkerCtx:
self.__allowed_annotations: dict[str, Any] = allowed_annotations
self.__ext: dict[Any, ChainMap] = {} # per-trigger namespaces (lazy)
# NEW: edge routing metadata (walker sets; triggers read)
self.__edge_role: str | None = None # 'elem' | 'key' | 'val' | 'branch' | 'annotated' | 'arg'
self.__edge_token: Any | None = None # index/key/branch-id/etc.
self.__arg_index: int | None = None # which schema arg of the parent this child is
@property
def origin(self) -> Any:
return self.__origin
@@ -99,6 +104,19 @@ class AnnotationWalkerCtx:
def allowed_annotations(self) -> Mapping[str, Any]:
return self.__allowed_annotations
# NEW: read-only edge metadata for routing inside triggers
@property
def edge_role(self) -> str | None:
return self.__edge_role
@property
def edge_token(self) -> Any | None:
return self.__edge_token
@property
def arg_index(self) -> int | None:
return self.__arg_index
def ns(self, owner: Any) -> ChainMap:
"""
A per-trigger overlay namespace that inherits from parent ctx.
@@ -107,11 +125,30 @@ class AnnotationWalkerCtx:
"""
if owner in self.__ext:
return self.__ext[owner]
parent_map = self.__parent.__ext.get(owner) if (self.__parent and hasattr(self.__parent, "_AnnotationWalkerCtx__ext")) else {}
cm = ChainMap({}, parent_map if isinstance(parent_map, ChainMap) else dict(parent_map))
# Determine the parent chain for this owner
parent_chain = ()
if self.__parent is not None and hasattr(self.__parent, "_AnnotationWalkerCtx__ext"):
parent_map = self.__parent._AnnotationWalkerCtx__ext.get(owner) # access private attr intentionally
if isinstance(parent_map, ChainMap):
# IMPORTANT: preserve the whole chain (not the ChainMap object as a single mapping)
parent_chain = tuple(parent_map.maps)
elif isinstance(parent_map, dict):
parent_chain = (parent_map,)
else:
parent_chain = ()
# Build a new ChainMap whose first map is this node's local overlay
cm = ChainMap({}, *parent_chain)
self.__ext[owner] = cm
return cm
# INTERNAL (walker only): set edge metadata on a child ctx
def _set_edge(self, *, role: str | None, token: Any | None, arg_index: int | None) -> None:
self.__edge_role = role
self.__edge_token = token
self.__arg_index = arg_index
@dataclass(frozen=True)
class TriggerResult:
@@ -169,6 +206,12 @@ class AnnotationTrigger:
def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
return None
def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
return None
def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]) -> None:
pass
class LAMSchemaValidation(AnnotationTrigger):
def init_trigger(self) -> None:
@@ -224,6 +267,462 @@ class LAMSchemaValidation(AnnotationTrigger):
raise IncompletelyAnnotatedField(f"None is only accepted with Union, to implement Optional[]")
return None
def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
print(f"process_exit: {ctx.origin}")
return None
def end_trigger(self, ctx: AnnotationWalkerCtx):
pass
class DataValidation(AnnotationTrigger):
def __init__(self, value: Any) -> None:
self._root = value
def init_trigger(self) -> None:
self._seeded = False
def _bag(self, ctx: AnnotationWalkerCtx):
bag = ctx.ns(self)
if not self._seeded:
bag["value"] = self._root
bag["path"] = ()
self._seeded = True
return bag
def process_annotated(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
b = self._bag(ctx)
def process_union(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
b = self._bag(ctx)
def process_dict(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
b = self._bag(ctx)
def process_tuple(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
b = self._bag(ctx)
def process_list(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
b = self._bag(ctx)
def process_set(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
b = self._bag(ctx)
def process_unknown(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
b = self._bag(ctx)
def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
b = self._bag(ctx)
def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult:
b = self._bag(ctx)
def end_trigger(self, ctx: AnnotationWalkerCtx):
pass
class SchemaValidationError(TypeError):
def __init__(self, path: tuple[Any, ...], msg: str):
dotted = "".join(f"[{p}]" if isinstance(p, int) else (f".{p}" if path and i else str(p)) for i, p in enumerate(path)) or "<root>"
super().__init__(f"{dotted}: {msg}")
self.path = path
class HorizontalValidationTrigger(AnnotationTrigger):
def __init__(
self,
value: Any,
*,
strict_bool: bool = False,
collect_all: bool = True,
union_summary_max_per_branch: int = 3, # NEW
show_branch_types: bool = True, # NEW
):
self._root_value = value
self._strict_bool = strict_bool
self._collect_all = collect_all
self._union_summary_max = union_summary_max_per_branch
self._show_branch_types = show_branch_types
self._seeded = False
self._id_counter = 0
# ---------- utilities ----------
# ---------- utilities ----------
def _mk_cand(self, value: Any, path: tuple[Any, ...]):
cid = self._id_counter
self._id_counter += 1
return {"id": cid, "value": value, "path": path}
def _spawn_from(self, bag, parent_cand: dict, value: Any, path: tuple[Any, ...]):
"""
Create a child candidate inheriting union root id when inside a union branch.
"""
nc = self._mk_cand(value, path)
if "__union_branch_id" in bag:
# Ensure every descendant carries union-root id
nc["union_cid"] = parent_cand.get("union_cid", parent_cand["id"])
elif "union_cid" in parent_cand:
# Preserve if already present (e.g., nested under a branch)
nc["union_cid"] = parent_cand["union_cid"]
return nc
def _bag(self, ctx: AnnotationWalkerCtx):
bag = ctx.ns(self)
local = bag.maps[0]
parent_bag = ctx.parent.ns(self) if ctx.parent is not None else None
if not self._seeded and ctx.parent is None:
local["candidates"] = [self._mk_cand(self._root_value, ())]
local["errors"] = []
self._seeded = True
if "errors" not in local and parent_bag is not None:
local["errors"] = parent_bag.get("errors", [])
if ctx.parent is not None:
role = ctx.edge_role
if role == "elem":
local["candidates"] = list(parent_bag.get("pending_elem", []))
elif role == "key":
local["candidates"] = list(parent_bag.get("pending_key", []))
elif role == "val":
local["candidates"] = list(parent_bag.get("pending_val", []))
elif role == "arg":
local["candidates"] = list(parent_bag.get(f"pending_arg_{ctx.arg_index}", []))
elif role in ("branch", "annotated"):
local["candidates"] = list(parent_bag.get("candidates", []))
else:
local["candidates"] = list(parent_bag.get("candidates", []))
# If immediate child of a Union (branch), switch to branch-local errors
if ctx.parent is not None and ctx.parent.origin is UnionType and ctx.edge_role == "branch":
ubag = ctx.parent.ns(self)
branches = ubag.setdefault("union_branches_state", {})
bid = ctx.edge_token
bstate = branches.setdefault(bid, {"failed_ids": set(), "errors": []})
local["errors"] = bstate["errors"]
local["__union_branch_id"] = bid
local["__union_state_ref"] = ubag
# NEW: propagate the union scope id to this branch
local["__union_scope_id"] = ubag.get("__union_scope_id")
# Stamp union_cid for all candidates at branch entry
stamped = []
for c in local.get("candidates", []):
c2 = dict(c)
c2.setdefault("union_cid", c2["id"])
stamped.append(c2)
local["candidates"] = stamped
return bag
def _mark_branch_fail(self, bag, cand: dict):
if "__union_branch_id" in bag and "__union_state_ref" in bag:
bid = bag["__union_branch_id"]
uref = bag["__union_state_ref"]
branches = uref.setdefault("union_branches_state", {})
bstate = branches.setdefault(bid, {"failed_ids": set(), "errors": []})
root_id = cand.get("union_cid", cand["id"])
bstate["failed_ids"].add(root_id)
def _err(self, bag, cand, msg: str):
"""
Record an error. If we are inside a union branch, attach (cid, scope) so the
outer union can attribute it correctly to its branch/candidate.
"""
e = SchemaValidationError(cand["path"], msg)
# Find nearest union scope (walk ChainMap parents if needed)
def nearest_union_scope(b):
if "__union_scope_id" in b:
return b["__union_scope_id"]
# ChainMap.get() already searches parents; use get() with a sentinel
sentinel = object()
v = b.get("__union_scope_id", sentinel)
return None if v is sentinel else v
in_union_branch = ("__union_branch_id" in bag) and ("__union_state_ref" in bag)
if in_union_branch:
# union-root candidate id for attribution (propagated by spawn_from / branch entry)
root_id = cand.get("union_cid", cand["id"])
scope = nearest_union_scope(bag)
# Branch-local error bucket lives under union state
uref = bag["__union_state_ref"]
bid = bag["__union_branch_id"]
branches = uref.setdefault("union_branches_state", {})
bstate = branches.setdefault(bid, {"failed_ids": set(), "errors": []})
bstate["failed_ids"].add(root_id)
bstate["errors"].append({"cid": root_id, "scope": scope, "err": e})
return
# Outside any union branch: add to global errors (and optional fail-fast)
bag["errors"].append(e)
if not self._collect_all:
raise e
def _check_leaf(self, bag, cand: dict, T: type):
v = cand["value"]
if T is type(None):
if v is not None:
self._err(bag, cand, "expected None")
return
if self._strict_bool and T is bool and type(v) is not bool:
self._err(bag, cand, f"expected bool, got {type(v).__name__}")
return
if v is None or not isinstance(v, T):
got = "None" if v is None else type(v).__name__
self._err(bag, cand, f"expected {T.__name__}, got {got}")
# ---------- entry hooks (seed pending; validate leaves; no aggregation here) ----------
def process_allowed(self, ctx: AnnotationWalkerCtx):
bag = self._bag(ctx)
T = ctx.origin
for cand in bag.get("candidates", []):
self._check_leaf(bag, cand, T)
def process_list(self, ctx: AnnotationWalkerCtx):
bag = self._bag(ctx)
cands = bag.get("candidates", [])
if len(ctx.args) != 1:
for cand in cands:
self._err(bag, cand, "List[T] requires 1 argument")
return
pending = []
for cand in cands:
v = cand["value"]
if not isinstance(v, list):
self._err(bag, cand, f"expected list, got {type(v).__name__}")
continue
base = cand["path"]
for i in range(len(v)):
pending.append(self._spawn_from(bag, cand, v[i], base + (i,)))
bag["pending_elem"] = pending
def process_tuple(self, ctx: AnnotationWalkerCtx):
bag = self._bag(ctx)
cands = bag.get("candidates", [])
vararg = len(ctx.args) == 2 and ctx.args[1] is Ellipsis
if vararg:
pending = []
for cand in cands:
v = cand["value"]
if not isinstance(v, tuple):
self._err(bag, cand, f"expected tuple, got {type(v).__name__}")
continue
base = cand["path"]
for i in range(len(v)):
pending.append(self._spawn_from(bag, cand, v[i], base + (i,)))
bag["pending_elem"] = pending
else:
arity = len(ctx.args)
slots: dict[int, list] = {}
for cand in cands:
v = cand["value"]
if not isinstance(v, tuple):
self._err(bag, cand, f"expected tuple, got {type(v).__name__}")
continue
if len(v) != arity:
self._err(bag, cand, f"expected tuple len {arity}, got {len(v)}")
continue
base = cand["path"]
for i, elem in enumerate(v):
slots.setdefault(i, []).append(self._spawn_from(bag, cand, elem, base + (i,)))
for i, group in slots.items():
bag[f"pending_arg_{i}"] = group
def process_set(self, ctx: AnnotationWalkerCtx):
bag = self._bag(ctx)
cands = bag.get("candidates", [])
if len(ctx.args) != 1:
for cand in cands:
self._err(bag, cand, "Set[T]/FrozenSet[T] requires 1 argument")
return
pending = []
for cand in cands:
v = cand["value"]
if not isinstance(v, (set, frozenset)):
self._err(bag, cand, f"expected set/frozenset, got {type(v).__name__}")
continue
base = cand["path"]
for e in v:
pending.append(self._spawn_from(bag, cand, e, base + ("<elem>",)))
bag["pending_elem"] = pending
def process_dict(self, ctx: AnnotationWalkerCtx):
bag = self._bag(ctx)
cands = bag.get("candidates", [])
if len(ctx.args) != 2:
for cand in cands:
self._err(bag, cand, "Dict[K,V] requires 2 arguments")
return
pkeys, pvals = [], []
for cand in cands:
v = cand["value"]
if not isinstance(v, dict):
self._err(bag, cand, f"expected dict, got {type(v).__name__}")
continue
base = cand["path"]
for k, val in v.items():
pkeys.append(self._spawn_from(bag, cand, k, base + ("<key>",)))
pvals.append(self._spawn_from(bag, cand, val, base + (k,)))
bag["pending_key"] = pkeys
bag["pending_val"] = pvals
def process_annotated(self, ctx: AnnotationWalkerCtx):
# No checks here; inner T will validate routed candidates.
self._bag(ctx)
def process_union(self, ctx: AnnotationWalkerCtx):
bag = self._bag(ctx)
# Candidates arriving at THIS union node (e.g., per list element when union is element type)
cands = bag.get("candidates", [])
# Track these exact candidate ids for this union scope
bag["union_candidate_ids"] = [c["id"] for c in cands]
# One entry per branch; each branch accumulates scoped errors keyed by union-root cid
bag.setdefault("union_branches_state", {}) # bid -> {"failed_ids": set(), "errors": []}
bag["union_branch_labels"] = [self._pretty_type(a) for a in ctx.args]
# Scope id distinguishes nested unions; only errors tagged with this scope count here
bag["__union_scope_id"] = id(ctx)
# ---------- exit (aggregation only) ----------
def process_exit(self, ctx: AnnotationWalkerCtx):
if ctx.origin is not UnionType:
return None
bag = ctx.ns(self)
root_errors = bag.get("errors", [])
branches = bag.get("union_branches_state", {})
cand_ids = bag.get("union_candidate_ids", [])
labels = bag.get("union_branch_labels", [])
scope_id = bag.get("__union_scope_id") # THIS union's scope
# Build: per_branch_errors[bid][cid] -> [SchemaValidationError, ...] filtered to THIS union scope
per_branch_errors: dict[int, dict[int, list[SchemaValidationError]]] = {}
for bid, state in branches.items():
berrs = state.get("errors", [])
m: dict[int, list[SchemaValidationError]] = {}
for item in berrs:
if isinstance(item, dict):
# Only errors tagged for THIS union scope
if item.get("scope") != scope_id:
continue
cid = item.get("cid")
err = item.get("err")
if cid is not None and isinstance(err, SchemaValidationError):
m.setdefault(cid, []).append(err)
elif isinstance(item, SchemaValidationError):
# Legacy/plain error: attribute to all failed ids of this branch
for cid in state.get("failed_ids", set()):
m.setdefault(cid, []).append(item)
# Dedupe identical (path, text)
for cid, lst in m.items():
seen = set()
uniq = []
for e in lst:
key = (e.path, str(e))
if key not in seen:
seen.add(key)
uniq.append(e)
m[cid] = uniq
per_branch_errors[bid] = m
# Decide per candidate of THIS union
for cid in cand_ids:
branch_ok = any(len(per_branch_errors.get(bid, {}).get(cid, [])) == 0 for bid in branches.keys())
if branch_ok:
continue # some branch matched, good
# Pretty summary for this failing candidate
# Get the path for this union's candidate; fall back to <root> if absent
cands_here = bag.get("candidates", [])
path = next((c["path"] for c in cands_here if c.get("id") == cid), ())
lines = ["no union branch matched; tried:"]
for bid, state in branches.items():
label = labels[bid] if bid < len(labels) else f"branch {bid}"
errs = per_branch_errors.get(bid, {}).get(cid, [])
if not errs:
lines.append(f" - {label}: mismatch")
continue
N = self._union_summary_max
shown = errs[:N]
lines.append(f" - {label}: {len(errs)} issue(s)")
for e in shown:
lines.append(f" - {e}")
if len(errs) > N:
lines.append(f" - (+{len(errs) - N} more)")
root_errors.append(SchemaValidationError(path, "\n".join(lines)))
# Also, if THIS union is nested inside an enclosing union branch, and ALL branches
# of THIS union failed for a candidate, bubble a single 'mismatch' up to that parent
# so the parent union can mark its branch as failed for that union-root cid.
if ("__union_state_ref" in bag) and ("__union_branch_id" in bag):
uref = bag["__union_state_ref"] # enclosing union's bag
up_bid = bag["__union_branch_id"] # which branch we're in
up_branches = uref.setdefault("union_branches_state", {})
up_state = up_branches.setdefault(up_bid, {"failed_ids": set(), "errors": []})
up_scope = uref.get("__union_scope_id")
# Map THIS union's local cid to the enclosing union's root cid via union_cid
local_cands = bag.get("candidates", [])
# Fallback map: if local_cands missing, assume 1:1
fallback_map = {cid: cid for cid in cand_ids}
map_local_to_outer = {}
for lc in local_cands:
local_id = lc.get("id")
outer_root = lc.get("union_cid", local_id)
map_local_to_outer[local_id] = outer_root
for cid in cand_ids:
branch_ok = any(len(per_branch_errors.get(bid, {}).get(cid, [])) == 0 for bid in branches.keys())
if branch_ok:
continue # do not bubble success
outer_cid = map_local_to_outer.get(cid, fallback_map[cid])
up_state["failed_ids"].add(outer_cid)
# lightweight marker so parent has at least one error on record for this cid
up_state["errors"].append({"cid": outer_cid, "scope": up_scope, "err": SchemaValidationError((), "mismatch")})
return None
def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]):
if ctx is None:
return
errors = ctx.ns(self).get("errors", [])
if errors:
# Raise them together; swap for your preferred error carrier if needed
raise ExceptionGroup("schema validation failed", errors)
# --- pretty type for messages (best-effort) ---
def _pretty_type(self, t: Any) -> str:
origin = get_origin(t) or t
args = get_args(t)
try:
if origin is UnionType or origin is Union:
return " | ".join(self._pretty_type(a) for a in args)
if origin in (list, tuple, set, frozenset, dict, Annotated):
if origin is dict and len(args) == 2:
return f"dict[{self._pretty_type(args[0])}, {self._pretty_type(args[1])}]"
if origin in (list, set, frozenset) and len(args) == 1:
name = "list" if origin is list else ("set" if origin is set else "frozenset")
return f"{name}[{self._pretty_type(args[0])}]"
if origin is tuple:
if len(args) == 2 and args[1] is Ellipsis:
return f"tuple[{self._pretty_type(args[0])}, ...]"
return f"tuple[{', '.join(self._pretty_type(a) for a in args)}]"
if origin is Annotated and args:
return f"Annotated[{self._pretty_type(args[0])}, ...]"
if isinstance(origin, type):
return origin.__name__
return str(t)
except Exception:
return repr(t)
class AnnotationWalker:
DEFAULT_ALLOWED_TYPES = frozenset({str, int, float, complex, bool, bytes, NoneType})
@@ -235,7 +734,7 @@ class AnnotationWalker:
"Dict": Dict,
"Tuple": Tuple,
"Set": Set,
"FrozenSet": FrozenSet,
# "FrozenSet": FrozenSet,
"Annotated": Annotated,
# builtins:
"int": int,
@@ -248,7 +747,7 @@ class AnnotationWalker:
"list": list,
"dict": dict,
"set": set,
"frozenset": frozenset,
# "frozenset": frozenset,
"tuple": tuple,
}
)
@@ -285,9 +784,12 @@ class AnnotationWalker:
self.__ann = eval(ann, {"__builtins__": {}}, self._allowed_annotations)
def run(self) -> TriggerResult:
for trigger in self._triggers:
trigger.init_trigger()
return self._walk(self.__ann, None)
for t in self._triggers:
t.init_trigger()
self.root_ctx = None
self._walk(self.__ann, None)
for t in self._triggers:
t.end_trigger(self.root_ctx)
# --- Helpers ---
@@ -296,8 +798,8 @@ class AnnotationWalker:
def _apply_triggers(self, method: str, ctx: AnnotationWalkerCtx) -> TriggerResult:
final = TriggerResult.passthrough()
for trig in self._triggers:
res = getattr(trig, method)(ctx)
for t in self._triggers:
res = getattr(t, method)(ctx)
if not res:
continue
if res.restart_with is not None:
@@ -311,33 +813,55 @@ class AnnotationWalker:
)
return final
def _apply_exits(self, ctx: AnnotationWalkerCtx) -> Any | None:
"""
Run exits in order: type-specific (if implemented) then generic `process_exit`.
Only `replace_with` is honored at exit; last `replace_with` wins.
"""
final = None
for t in self._triggers:
res = t.process_exit(ctx)
if res and (res.replace_with is not None):
final = res.replace_with
return final
def _handle_with_triggers(
self,
trigger_name: str,
ctx: AnnotationWalkerCtx,
args_handler: Callable[[AnnotationWalkerCtx], Any] | None = None,
) -> Any:
"""Generic handler: run triggers, maybe recurse into args with a custom handler."""
# ENTER
res = self._apply_triggers(trigger_name, ctx)
if res.restart_with is not None:
return self._walk(res.restart_with, ctx.parent)
if res.replace_with is not None:
return res.replace_with
exit_val = self._apply_exits(ctx)
return exit_val if exit_val is not None else res.replace_with
node_value = None
if not res.skip_children:
if args_handler:
return args_handler(ctx)
return tuple(self._walk(a, ctx) for a in ctx.args)
return None
node_value = args_handler(ctx)
else:
# DEFAULT: descend once per schema arg; mark as positional arg
node_value = tuple(self._walk_child(a, ctx, arg_index=i, role="arg", token=i) for i, a in enumerate(ctx.args))
# EXIT
exit_val = self._apply_exits(ctx)
return exit_val if exit_val is not None else node_value
def _walk_args_tuple(self, ctx: AnnotationWalkerCtx):
# special Ellipsis case for Tuple
# Tuple[T, ...] (variadic): one schema child, mark role='elem'
if len(ctx.args) == 2 and ctx.args[1] is Ellipsis:
return (self._walk(ctx.args[0], ctx), Ellipsis)
return tuple(self._walk(a, ctx) for a in ctx.args)
return (self._walk_child(ctx.args[0], ctx, arg_index=0, role="elem", token=None), Ellipsis)
# Fixed tuple: each positional arg gets role='arg'
return tuple(self._walk_child(a, ctx, arg_index=i, role="arg", token=i) for i, a in enumerate(ctx.args))
# --- Dispatcher ---
def _walk(self, type_: Any, parent_ctx: Optional[AnnotationWalkerCtx]) -> Any:
# For logs only: show the calling layer
print(f"[{parent_ctx.layer if parent_ctx else 0}] walking through: {type_}")
origin = get_origin(type_) or type_
@@ -349,31 +873,102 @@ class AnnotationWalker:
raise RuntimeError("Annotation must be using type(s), not instances")
args = get_args(type_)
layer = 0 if parent_ctx is None else parent_ctx.layer + 1
ctx = self._new_ctx(origin, args, layer, parent_ctx)
# IMPORTANT:
# If caller (_walk_child) already constructed a child ctx with edge metadata,
# reuse that object as *the* ctx for this node instead of allocating a new one.
if isinstance(parent_ctx, AnnotationWalkerCtx) and parent_ctx.origin is origin and parent_ctx.args == args:
ctx = parent_ctx
else:
# Root or internal calls that didn't prebuild the ctx
layer = 0 if parent_ctx is None else parent_ctx.layer + 1
ctx = self._new_ctx(origin, args, layer, parent_ctx)
# Remember root ctx for end_trigger()
if ctx.parent is None:
self.root_ctx = ctx
print(origin)
match origin:
case typing.Annotated:
# inner type gets role='annotated'
return self._handle_with_triggers(
"process_annotated", ctx, args_handler=lambda c: self._walk(c.args[0], c) if c.args else None
"process_annotated",
ctx,
args_handler=lambda c: self._walk_child(c.args[0], c, arg_index=0, role="annotated", token=None) if c.args else None,
)
case types.UnionType:
return self._handle_with_triggers("process_union", ctx)
# branches get role='branch' and token=branch index
return self._handle_with_triggers(
"process_union",
ctx,
args_handler=lambda c: tuple(self._walk_child(a, c, arg_index=i, role="branch", token=i) for i, a in enumerate(c.args)),
)
case _ if issubclass(origin, dict):
return self._handle_with_triggers("process_dict", ctx)
# arg0=key (role='key'), arg1=value (role='val')
return self._handle_with_triggers(
"process_dict",
ctx,
args_handler=lambda c: (
self._walk_child(c.args[0], c, arg_index=0, role="key", token=None),
self._walk_child(c.args[1], c, arg_index=1, role="val", token=None),
),
)
case _ if issubclass(origin, tuple):
return self._handle_with_triggers("process_tuple", ctx, self._walk_args_tuple)
case _ if issubclass(origin, list):
return self._handle_with_triggers("process_list", ctx)
# single child T with role='elem'
return self._handle_with_triggers(
"process_list",
ctx,
args_handler=lambda c: self._walk_child(c.args[0], c, arg_index=0, role="elem", token=None) if c.args else None,
)
case _ if issubclass(origin, set):
return self._handle_with_triggers("process_set", ctx)
# single child T with role='elem'
return self._handle_with_triggers(
"process_set",
ctx,
args_handler=lambda c: self._walk_child(c.args[0], c, arg_index=0, role="elem", token=None) if c.args else None,
)
case _ if origin in self._allowed_types:
return self._handle_with_triggers("process_allowed", ctx)
case _:
res = self._apply_triggers("process_unknown", ctx)
if res.restart_with is not None:
return self._walk(res.restart_with, ctx.parent)
if res.replace_with is not None:
return res.replace_with
raise UnsupportedFieldType(f"Not supported Field: {ctx.origin}, " f"Supported list: {self._allowed_types}")
raise UnsupportedFieldType(f"Not supported Field: {ctx.origin}, Supported list: {self._allowed_types}")
def _walk_child(
self,
type_expr: Any,
parent_ctx: AnnotationWalkerCtx,
*,
arg_index: int,
role: str | None,
token: Any | None,
) -> Any:
origin = get_origin(type_expr) or type_expr
if origin is None:
origin = NoneType
if origin is Union:
origin = UnionType
if not isinstance(origin, type):
raise RuntimeError("Annotation must be using type(s), not instances")
args = get_args(type_expr)
child = self._new_ctx(origin, args, parent_ctx.layer + 1, parent_ctx)
# stamp routing metadata for triggers
child._set_edge(role=role, token=token, arg_index=arg_index)
# IMPORTANT: recurse with the CHILD as the parent_ctx for the next step,
# so the walker uses this child ctx (with edge metadata & incremented layer).
return self._walk(type_expr, child)

View File

@@ -22,11 +22,18 @@ testdir_path = Path(__file__).parent.resolve()
chdir(testdir_path.parent.resolve())
class ElementTest(unittest.TestCase):
class AnnotationsWalkerTest(unittest.TestCase):
def setUp(self):
print("\n->", unittest.TestCase.id(self))
def test_element_simple(self):
def test_validate(self):
ann = dict[int, list[int]] | dict[int, list[int | str]]
val = {1: [2], 2: ["a", [1]]}
res = dm.tools.AnnotationWalker(ann, (dm.tools.HorizontalValidationTrigger(val),))
res.run()
def test_simple(self):
print(isinstance(None, type(None)))
print("\n== From OBJs ==")
res = dm.tools.AnnotationWalker(Annotated[Optional[dict[int, list[str]]], "comment"], (dm.tools.LAMSchemaValidation(),))