diff --git a/src/dabmodel/tools.py b/src/dabmodel/tools.py index 5534c51..385e64c 100644 --- a/src/dabmodel/tools.py +++ b/src/dabmodel/tools.py @@ -73,953 +73,3 @@ def _resolve_annotation(ann): # Safe eval against a **whitelist** only return eval(ann, {"__builtins__": {}}, ALLOWED_ANNOTATIONS) # pylint: disable=eval-used return ann - - -class AnnotationWalkerCtx: - def __init__( - self, - origin: Any, - args: Any, - layer: int, - parent: Optional[Self] = None, - allowed_types: set[type, ...] = frozenset(), - allowed_annotations: dict[str, Any] = {}, - ): - self.__origin = origin - self.args = args - self.__layer = layer - self.__parent = parent - - self.__allowed_types: set[type, ...] = allowed_types - self.__allowed_annotations: dict[str, Any] = allowed_annotations - self.__ext: dict[Any, ChainMap] = {} # per-trigger namespaces (lazy) - - # NEW: edge routing metadata (walker sets; triggers read) - self.__edge_role: str | None = ( - None # 'elem' | 'key' | 'val' | 'branch' | 'annotated' | 'arg' - ) - self.__edge_token: Any | None = None # index/key/branch-id/etc. - self.__arg_index: int | None = None # which schema arg of the parent this child is - - @property - def origin(self) -> Any: - return self.__origin - - @property - def layer(self) -> int: - return self.__layer - - @property - def parent(self) -> Self: - return self.__parent - - @property - def allowed_types(self) -> FrozenSet[type]: - return self.__allowed_types - - @property - def allowed_annotations(self) -> Mapping[str, Any]: - return self.__allowed_annotations - - # NEW: read-only edge metadata for routing inside triggers - @property - def edge_role(self) -> str | None: - return self.__edge_role - - @property - def edge_token(self) -> Any | None: - return self.__edge_token - - @property - def arg_index(self) -> int | None: - return self.__arg_index - - def ns(self, owner: Any) -> ChainMap: - """ - A per-trigger overlay namespace that inherits from parent ctx. - Use as: bag = ctx.ns(self); bag['whatever'] = ... - Lookups fall back to parent's bag automatically. - """ - if owner in self.__ext: - return self.__ext[owner] - - # Determine the parent chain for this owner - parent_chain = () - if self.__parent is not None and hasattr(self.__parent, "_AnnotationWalkerCtx__ext"): - parent_map = self.__parent._AnnotationWalkerCtx__ext.get( - owner - ) # access private attr intentionally - if isinstance(parent_map, ChainMap): - # IMPORTANT: preserve the whole chain (not the ChainMap object as a single mapping) - parent_chain = tuple(parent_map.maps) - elif isinstance(parent_map, dict): - parent_chain = (parent_map,) - else: - parent_chain = () - - # Build a new ChainMap whose first map is this node's local overlay - cm = ChainMap({}, *parent_chain) - self.__ext[owner] = cm - return cm - - # INTERNAL (walker only): set edge metadata on a child ctx - def _set_edge(self, *, role: str | None, token: Any | None, arg_index: int | None) -> None: - self.__edge_role = role - self.__edge_token = token - self.__arg_index = arg_index - - -@dataclass(frozen=True) -class TriggerResult: - # If provided, children won't be walked and this value is returned. - replace_with: Any | None = None - # If true, skip walking children but don't replace current node value. - skip_children: bool = False - # If provided, walker will restart processing with the given value - restart_with: Any | None = None # NEW - - @staticmethod - def passthrough() -> Self: - return TriggerResult() - - @staticmethod - def replace(value: Any) -> Self: - return TriggerResult(replace_with=value, skip_children=True) - - @staticmethod - def skip() -> Self: - return TriggerResult(skip_children=True) - - @staticmethod - def restart(value: Any) -> Self: - print("Doo!") - return TriggerResult(restart_with=value) - - -class AnnotationTrigger: - - def init_trigger(self) -> None: - pass - - def process_annotated(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - return None - - def process_union(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - return None - - def process_dict(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - return None - - def process_tuple(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - return None - - def process_list(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - return None - - def process_set(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - return None - - def process_unknown(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - return None - - def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - return None - - def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - return None - - def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]) -> None: - pass - - -class LAMSchemaValidation(AnnotationTrigger): - def init_trigger(self) -> None: - print(f"Initializing {self.__class__.__name__}") - - def process_annotated(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - print("process_annotated") - print(ctx.origin) - print(ctx.args) - if len(ctx.args) != 2: - raise UnsupportedFieldType("Annotated[T,x] requires 2 parameters") - if ctx.parent is not None: - raise UnsupportedFieldType("Annotated[T,x] is only supported as parent annotation") - return None - - def process_union(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - print("process_union") - print(ctx.args) - if (len(ctx.args) != 2) or (type(None) not in list(ctx.args)): - raise UnsupportedFieldType( - "Union[] is only supported to implement Optional[] (takes 2 parameters, including None)" - ) - return None - - def process_dict(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - print("process_dict") - if len(ctx.args) != 2: - raise IncompletelyAnnotatedField( - f"Dict Annotation requires 2 inner definitions: {ctx.origin}" - ) - if not ctx.args[0] in ctx.allowed_types: - raise IncompletelyAnnotatedField(f"Dict Key must be simple builtin: {ctx.origin}") - return None - - def process_tuple(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - print("process_tuple") - if len(ctx.args) == 0: - raise IncompletelyAnnotatedField(f"Annotation requires inner definition: {ctx.origin}") - return None - - def process_list(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - print("process_list") - if len(ctx.args) == 0: - raise IncompletelyAnnotatedField(f"Annotation requires inner definition: {ctx.origin}") - return None - - def process_set(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - print("process_set") - if len(ctx.args) == 0: - raise IncompletelyAnnotatedField(f"Annotation requires inner definition: {ctx.origin}") - return None - - def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - print("process_allowed") - if ctx.origin is type(None) or ctx.origin is None: - if ctx.parent is None or not ( - ctx.parent.origin is Union or ctx.parent.origin is UnionType - ): - raise IncompletelyAnnotatedField( - f"None is only accepted with Union, to implement Optional[]" - ) - return None - - def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - print(f"process_exit: {ctx.origin}") - return None - - def end_trigger(self, ctx: AnnotationWalkerCtx): - pass - - -class DataValidation(AnnotationTrigger): - def __init__(self, value: Any) -> None: - self._root = value - - def init_trigger(self) -> None: - self._seeded = False - - def _bag(self, ctx: AnnotationWalkerCtx): - bag = ctx.ns(self) - if not self._seeded: - bag["value"] = self._root - bag["path"] = () - self._seeded = True - return bag - - def process_annotated(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - b = self._bag(ctx) - - def process_union(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - b = self._bag(ctx) - - def process_dict(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - b = self._bag(ctx) - - def process_tuple(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - b = self._bag(ctx) - - def process_list(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - b = self._bag(ctx) - - def process_set(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - b = self._bag(ctx) - - def process_unknown(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - b = self._bag(ctx) - - def process_allowed(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - b = self._bag(ctx) - - def process_exit(self, ctx: AnnotationWalkerCtx) -> None | TriggerResult: - b = self._bag(ctx) - - def end_trigger(self, ctx: AnnotationWalkerCtx): - pass - - -class SchemaValidationError(TypeError): - def __init__(self, path: tuple[Any, ...], msg: str): - dotted = ( - "".join( - f"[{p}]" if isinstance(p, int) else (f".{p}" if path and i else str(p)) - for i, p in enumerate(path) - ) - or "" - ) - super().__init__(f"{dotted}: {msg}") - self.path = path - - -class HorizontalValidationTrigger(AnnotationTrigger): - def __init__( - self, - value: Any, - *, - strict_bool: bool = False, - collect_all: bool = True, - union_summary_max_per_branch: int = 3, # NEW - show_branch_types: bool = True, # NEW - ): - self._root_value = value - self._strict_bool = strict_bool - self._collect_all = collect_all - self._union_summary_max = union_summary_max_per_branch - self._show_branch_types = show_branch_types - self._seeded = False - self._id_counter = 0 - - # ---------- utilities ---------- - - # ---------- utilities ---------- - - def _mk_cand(self, value: Any, path: tuple[Any, ...]): - cid = self._id_counter - self._id_counter += 1 - return {"id": cid, "value": value, "path": path} - - def _spawn_from(self, bag, parent_cand: dict, value: Any, path: tuple[Any, ...]): - nc = self._mk_cand(value, path) - # Inherit union_roots map from parent (create if missing) - roots = dict(parent_cand.get("union_roots", {})) - nc["union_roots"] = roots - - # If we are in a union branch, ensure current scope is stamped - if "__union_branch_id" in bag and "__union_state_ref" in bag: - cur_scope = bag.get("__union_scope_id") - if cur_scope is not None: - roots.setdefault( - cur_scope, parent_cand.get("union_roots", {}).get(cur_scope, parent_cand["id"]) - ) - - return nc - - def _bag(self, ctx: AnnotationWalkerCtx): - bag = ctx.ns(self) - local = bag.maps[0] - parent_bag = ctx.parent.ns(self) if ctx.parent is not None else None - - # (existing seeding & routing)... - - # ------------- NEW: inherit enclosing union flags to all descendants ------------- - if parent_bag is not None: - for k in ("__union_state_ref", "__union_branch_id", "__union_scope_id"): - if (k in parent_bag) and (k not in local): - local[k] = parent_bag[k] - - # ------------- If immediate child of a Union (branch) ------------- - if ctx.parent is not None and ctx.parent.origin is UnionType and ctx.edge_role == "branch": - ubag = ctx.parent.ns(self) - branches = ubag.setdefault("union_branches_state", {}) - bid = ctx.edge_token - bstate = branches.setdefault(bid, {"failed_ids": set(), "errors": []}) - local["errors"] = bstate["errors"] - local["__union_branch_id"] = bid - local["__union_state_ref"] = ubag - # keep current union scope id (from parent union node) - local["__union_scope_id"] = ubag.get("__union_scope_id") - - # STAMP union_roots for all candidates at branch entry - stamped = [] - cur_scope = local.get("__union_scope_id") - for c in local.get("candidates", []): - c2 = dict(c) - roots = dict(c2.get("union_roots", {})) - if cur_scope is not None: - roots.setdefault(cur_scope, roots.get(cur_scope, c2["id"])) - c2["union_roots"] = roots - stamped.append(c2) - local["candidates"] = stamped - - return bag - - def _mark_branch_fail(self, bag, cand: dict): - if "__union_branch_id" in bag and "__union_state_ref" in bag: - bid = bag["__union_branch_id"] - uref = bag["__union_state_ref"] - branches = uref.setdefault("union_branches_state", {}) - bstate = branches.setdefault(bid, {"failed_ids": set(), "errors": []}) - root_id = cand.get("union_cid", cand["id"]) - bstate["failed_ids"].add(root_id) - - def _err(self, bag, cand, msg: str): - """ - Record an error. If we are inside a union branch, attach (cid, scope) so the - outer union can attribute it correctly to its branch/candidate. - """ - e = SchemaValidationError(cand["path"], msg) - - # Find nearest union scope (walk ChainMap parents if needed) - def nearest_union_scope(b): - if "__union_scope_id" in b: - return b["__union_scope_id"] - # ChainMap.get() already searches parents; use get() with a sentinel - sentinel = object() - v = b.get("__union_scope_id", sentinel) - return None if v is sentinel else v - - in_union_branch = ("__union_branch_id" in bag) and ("__union_state_ref" in bag) - if in_union_branch: - # union-root candidate id for attribution (propagated by spawn_from / branch entry) - root_id = cand.get("union_cid", cand["id"]) - scope = nearest_union_scope(bag) - # Branch-local error bucket lives under union state - uref = bag["__union_state_ref"] - bid = bag["__union_branch_id"] - branches = uref.setdefault("union_branches_state", {}) - bstate = branches.setdefault(bid, {"failed_ids": set(), "errors": []}) - bstate["failed_ids"].add(root_id) - bstate["errors"].append({"cid": root_id, "scope": scope, "err": e}) - return - - # Outside any union branch: add to global errors (and optional fail-fast) - bag["errors"].append(e) - if not self._collect_all: - raise e - - def _check_leaf(self, bag, cand: dict, T: type): - v = cand["value"] - if T is type(None): - if v is not None: - self._err(bag, cand, "expected None") - return - if self._strict_bool and T is bool and type(v) is not bool: - self._err(bag, cand, f"expected bool, got {type(v).__name__}") - return - if v is None or not isinstance(v, T): - got = "None" if v is None else type(v).__name__ - self._err(bag, cand, f"expected {T.__name__}, got {got}") - - # ---------- entry hooks (seed pending; validate leaves; no aggregation here) ---------- - - def process_allowed(self, ctx: AnnotationWalkerCtx): - bag = self._bag(ctx) - T = ctx.origin - for cand in bag.get("candidates", []): - self._check_leaf(bag, cand, T) - - def process_list(self, ctx: AnnotationWalkerCtx): - bag = self._bag(ctx) - cands = bag.get("candidates", []) - if len(ctx.args) != 1: - for cand in cands: - self._err(bag, cand, "List[T] requires 1 argument") - return - pending = [] - for cand in cands: - v = cand["value"] - if not isinstance(v, list): - self._err(bag, cand, f"expected list, got {type(v).__name__}") - continue - base = cand["path"] - for i in range(len(v)): - pending.append(self._spawn_from(bag, cand, v[i], base + (i,))) - bag["pending_elem"] = pending - - def process_tuple(self, ctx: AnnotationWalkerCtx): - bag = self._bag(ctx) - cands = bag.get("candidates", []) - vararg = len(ctx.args) == 2 and ctx.args[1] is Ellipsis - if vararg: - pending = [] - for cand in cands: - v = cand["value"] - if not isinstance(v, tuple): - self._err(bag, cand, f"expected tuple, got {type(v).__name__}") - continue - base = cand["path"] - for i in range(len(v)): - pending.append(self._spawn_from(bag, cand, v[i], base + (i,))) - bag["pending_elem"] = pending - else: - arity = len(ctx.args) - slots: dict[int, list] = {} - for cand in cands: - v = cand["value"] - if not isinstance(v, tuple): - self._err(bag, cand, f"expected tuple, got {type(v).__name__}") - continue - if len(v) != arity: - self._err(bag, cand, f"expected tuple len {arity}, got {len(v)}") - continue - base = cand["path"] - for i, elem in enumerate(v): - slots.setdefault(i, []).append(self._spawn_from(bag, cand, elem, base + (i,))) - for i, group in slots.items(): - bag[f"pending_arg_{i}"] = group - - def process_set(self, ctx: AnnotationWalkerCtx): - bag = self._bag(ctx) - cands = bag.get("candidates", []) - if len(ctx.args) != 1: - for cand in cands: - self._err(bag, cand, "Set[T]/FrozenSet[T] requires 1 argument") - return - pending = [] - for cand in cands: - v = cand["value"] - if not isinstance(v, (set, frozenset)): - self._err(bag, cand, f"expected set/frozenset, got {type(v).__name__}") - continue - base = cand["path"] - for e in v: - pending.append(self._spawn_from(bag, cand, e, base + ("",))) - bag["pending_elem"] = pending - - def process_dict(self, ctx: AnnotationWalkerCtx): - bag = self._bag(ctx) - cands = bag.get("candidates", []) - if len(ctx.args) != 2: - for cand in cands: - self._err(bag, cand, "Dict[K,V] requires 2 arguments") - return - pkeys, pvals = [], [] - for cand in cands: - v = cand["value"] - if not isinstance(v, dict): - self._err(bag, cand, f"expected dict, got {type(v).__name__}") - continue - base = cand["path"] - for k, val in v.items(): - pkeys.append(self._spawn_from(bag, cand, k, base + ("",))) - pvals.append(self._spawn_from(bag, cand, val, base + (k,))) - bag["pending_key"] = pkeys - bag["pending_val"] = pvals - - def process_annotated(self, ctx: AnnotationWalkerCtx): - # No checks here; inner T will validate routed candidates. - self._bag(ctx) - - def process_union(self, ctx: AnnotationWalkerCtx): - bag = self._bag(ctx) - cands = bag.get("candidates", []) - # Preserve enclosing union scope id (if any) so inner unions can bubble to it - if "__outer_union_scope_id" not in bag and "__union_scope_id" in bag: - bag["__outer_union_scope_id"] = bag["__union_scope_id"] - - bag["union_candidate_ids"] = [c["id"] for c in cands] - bag.setdefault("union_branches_state", {}) # bid -> {"failed_ids": set(), "errors": []} - bag["union_branch_labels"] = [self._pretty_type(a) for a in ctx.args] - # Set THIS union's scope id (distinct from the preserved outer scope) - bag["__union_scope_id"] = id(ctx) - - # ---------- exit (aggregation only) ---------- - - def process_exit(self, ctx: AnnotationWalkerCtx): - if ctx.origin is not UnionType: - return None - - bag = ctx.ns(self) - root_errors = bag.get("errors", []) - branches = bag.get("union_branches_state", {}) - cand_ids = bag.get("union_candidate_ids", []) - labels = bag.get("union_branch_labels", []) - scope_id = bag.get("__union_scope_id") - outer_scope_id = bag.get("__outer_union_scope_id") # preserved from parent (if any) - - # Build per-branch error map for THIS union scope - per_branch_errors: dict[int, dict[int, list[SchemaValidationError]]] = {} - for bid, state in branches.items(): - berrs = state.get("errors", []) - m: dict[int, list[SchemaValidationError]] = {} - for item in berrs: - if isinstance(item, dict): - if item.get("scope") != scope_id: - continue - cid = item.get("cid") - err = item.get("err") - if cid is not None and isinstance(err, SchemaValidationError): - m.setdefault(cid, []).append(err) - elif isinstance(item, SchemaValidationError): - for cid in state.get("failed_ids", set()): - m.setdefault(cid, []).append(item) - # de-dupe per cid - for cid, lst in m.items(): - seen = set() - uniq = [] - for e in lst: - key = (e.path, str(e)) - if key not in seen: - seen.add(key) - uniq.append(e) - m[cid] = uniq - per_branch_errors[bid] = m - - # Decision: pass if ANY branch has zero scoped errors - for cid in cand_ids: - branch_pass = any( - len(per_branch_errors.get(bid, {}).get(cid, [])) == 0 for bid in branches.keys() - ) - - if branch_pass: - continue - - # --- Bubble to enclosing union if we are inside one (i.e., this union is in a branch) --- - if ("__union_state_ref" in bag) and ("__union_branch_id" in bag): - uref = bag["__union_state_ref"] # enclosing union node bag - up_bid = bag["__union_branch_id"] # which enclosing branch we are in - up_branches = uref.setdefault("union_branches_state", {}) - up_state = up_branches.setdefault(up_bid, {"failed_ids": set(), "errors": []}) - up_scope = uref.get("__union_scope_id") - - # Find local candidate to recover outer root via union_roots - cands_here = bag.get("candidates", []) - lc = next((c for c in cands_here if c["id"] == cid), None) - if lc is not None and outer_scope_id is not None: - outer_root = lc.get("union_roots", {}).get(outer_scope_id, lc["id"]) - else: - outer_root = cid - - up_state["failed_ids"].add(outer_root) - up_state["errors"].append( - { - "cid": outer_root, - "scope": up_scope, - "err": SchemaValidationError((), "mismatch"), - } - ) - - # Build readable per-branch summary at THIS union site - cands_here = bag.get("candidates", []) - path = next((c["path"] for c in cands_here if c["id"] == cid), ()) - lines = ["no union branch matched; tried:"] - for bid, state in branches.items(): - label = labels[bid] if bid < len(labels) else f"branch {bid}" - these = per_branch_errors.get(bid, {}).get(cid, []) - if not these: - lines.append(f" - {label}: mismatch") - continue - N = getattr(self, "_union_summary_max", 3) - shown = these[:N] - lines.append(f" - {label}: {len(these)} issue(s)") - for e in shown: - lines.append(f" - {e}") - if len(these) > N: - lines.append(f" - (+{len(these) - N} more)") - root_errors.append(SchemaValidationError(path, "\n".join(lines))) - return None - - def end_trigger(self, ctx: Optional[AnnotationWalkerCtx]): - if ctx is None: - return - errors = ctx.ns(self).get("errors", []) - if errors: - # Raise them together; swap for your preferred error carrier if needed - raise ExceptionGroup("schema validation failed", errors) - - # --- pretty type for messages (best-effort) --- - def _pretty_type(self, t: Any) -> str: - origin = get_origin(t) or t - args = get_args(t) - try: - if origin is UnionType or origin is Union: - return " | ".join(self._pretty_type(a) for a in args) - if origin in (list, tuple, set, frozenset, dict, Annotated): - if origin is dict and len(args) == 2: - return f"dict[{self._pretty_type(args[0])}, {self._pretty_type(args[1])}]" - if origin in (list, set, frozenset) and len(args) == 1: - name = "list" if origin is list else ("set" if origin is set else "frozenset") - return f"{name}[{self._pretty_type(args[0])}]" - if origin is tuple: - if len(args) == 2 and args[1] is Ellipsis: - return f"tuple[{self._pretty_type(args[0])}, ...]" - return f"tuple[{', '.join(self._pretty_type(a) for a in args)}]" - if origin is Annotated and args: - return f"Annotated[{self._pretty_type(args[0])}, ...]" - if isinstance(origin, type): - return origin.__name__ - return str(t) - except Exception: - return repr(t) - - -class AnnotationWalker: - DEFAULT_ALLOWED_TYPES = frozenset({str, int, float, complex, bool, bytes, NoneType}) - DEFAULT_ALLOWED_ANNOTATIONS: dict[str, Any] = frozendict( - { - "Union": Union, - "Optional": Optional, - "List": List, - "Dict": Dict, - "Tuple": Tuple, - "Set": Set, - # "FrozenSet": FrozenSet, - "Annotated": Annotated, - # builtins: - "int": int, - "str": str, - "float": float, - "bool": bool, - "complex": complex, - "bytes": bytes, - "None": type(None), - "list": list, - "dict": dict, - "set": set, - # "frozenset": frozenset, - "tuple": tuple, - } - ) - - def __init__(self, ann: Any, triggers: tuple[AnnotationTrigger, ...], **kwargs): - if not triggers: - raise RuntimeError("AnnotationWalker requires trigger(s)") - - # Normalize triggers into instances - insts: list[AnnotationTrigger] = [] - for t in triggers if isinstance(triggers, tuple) else (triggers,): - if isinstance(t, AnnotationTrigger): - insts.append(t) - elif isinstance(t, type) and issubclass(t, AnnotationTrigger): - insts.append(t()) - else: - raise RuntimeError(f"Unsupported trigger: {t}") - self._triggers = tuple(insts) - - # Allowed types / annotations - atypes = set(type(self).DEFAULT_ALLOWED_TYPES) - if "ex_allowed_types" in kwargs: - atypes.update(kwargs["ex_allowed_types"]) - self._allowed_types = frozenset(atypes) - - annots = dict(type(self).DEFAULT_ALLOWED_ANNOTATIONS) - if "ex_allowed_annotations" in kwargs: - annots.update(kwargs["ex_allowed_annotations"]) - self._allowed_annotations = frozendict(annots) - - # Annotation can be string - self.__ann = ann - if isinstance(ann, str): - self.__ann = eval(ann, {"__builtins__": {}}, self._allowed_annotations) - - def run(self) -> TriggerResult: - for t in self._triggers: - t.init_trigger() - self.root_ctx = None - self._walk(self.__ann, None) - for t in self._triggers: - t.end_trigger(self.root_ctx) - - # --- Helpers --- - - def _new_ctx(self, origin, args, layer, parent): - return AnnotationWalkerCtx( - origin, args, layer, parent, self._allowed_types, self._allowed_annotations - ) - - def _apply_triggers(self, method: str, ctx: AnnotationWalkerCtx) -> TriggerResult: - final = TriggerResult.passthrough() - for t in self._triggers: - res = getattr(t, method)(ctx) - if not res: - continue - if res.restart_with is not None: - return res # short-circuit on restart - if res.replace_with is not None: - final = TriggerResult.replace(res.replace_with) - if res.skip_children: - final = TriggerResult( - replace_with=final.replace_with, - skip_children=True, - ) - return final - - def _apply_exits(self, ctx: AnnotationWalkerCtx) -> Any | None: - """ - Run exits in order: type-specific (if implemented) then generic `process_exit`. - Only `replace_with` is honored at exit; last `replace_with` wins. - """ - final = None - for t in self._triggers: - res = t.process_exit(ctx) - if res and (res.replace_with is not None): - final = res.replace_with - return final - - def _handle_with_triggers( - self, - trigger_name: str, - ctx: AnnotationWalkerCtx, - args_handler: Callable[[AnnotationWalkerCtx], Any] | None = None, - ) -> Any: - # ENTER - res = self._apply_triggers(trigger_name, ctx) - if res.restart_with is not None: - return self._walk(res.restart_with, ctx.parent) - if res.replace_with is not None: - exit_val = self._apply_exits(ctx) - return exit_val if exit_val is not None else res.replace_with - - node_value = None - if not res.skip_children: - if args_handler: - node_value = args_handler(ctx) - else: - # DEFAULT: descend once per schema arg; mark as positional arg - node_value = tuple( - self._walk_child(a, ctx, arg_index=i, role="arg", token=i) - for i, a in enumerate(ctx.args) - ) - - # EXIT - exit_val = self._apply_exits(ctx) - return exit_val if exit_val is not None else node_value - - def _walk_args_tuple(self, ctx: AnnotationWalkerCtx): - # Tuple[T, ...] (variadic): one schema child, mark role='elem' - if len(ctx.args) == 2 and ctx.args[1] is Ellipsis: - return ( - self._walk_child(ctx.args[0], ctx, arg_index=0, role="elem", token=None), - Ellipsis, - ) - # Fixed tuple: each positional arg gets role='arg' - return tuple( - self._walk_child(a, ctx, arg_index=i, role="arg", token=i) - for i, a in enumerate(ctx.args) - ) - - # --- Dispatcher --- - - def _walk(self, type_: Any, parent_ctx: Optional[AnnotationWalkerCtx]) -> Any: - # For logs only: show the calling layer - print(f"[{parent_ctx.layer if parent_ctx else 0}] walking through: {type_}") - - origin = get_origin(type_) or type_ - if origin is None: - origin = NoneType - if origin is Union: - origin = UnionType - if not isinstance(origin, type): - raise RuntimeError("Annotation must be using type(s), not instances") - - args = get_args(type_) - - # IMPORTANT: - # If caller (_walk_child) already constructed a child ctx with edge metadata, - # reuse that object as *the* ctx for this node instead of allocating a new one. - if ( - isinstance(parent_ctx, AnnotationWalkerCtx) - and parent_ctx.origin is origin - and parent_ctx.args == args - ): - ctx = parent_ctx - else: - # Root or internal calls that didn't prebuild the ctx - layer = 0 if parent_ctx is None else parent_ctx.layer + 1 - ctx = self._new_ctx(origin, args, layer, parent_ctx) - - # Remember root ctx for end_trigger() - if ctx.parent is None: - self.root_ctx = ctx - - print(origin) - match origin: - case typing.Annotated: - # inner type gets role='annotated' - return self._handle_with_triggers( - "process_annotated", - ctx, - args_handler=lambda c: ( - self._walk_child(c.args[0], c, arg_index=0, role="annotated", token=None) - if c.args - else None - ), - ) - - case types.UnionType: - # branches get role='branch' and token=branch index - return self._handle_with_triggers( - "process_union", - ctx, - args_handler=lambda c: tuple( - self._walk_child(a, c, arg_index=i, role="branch", token=i) - for i, a in enumerate(c.args) - ), - ) - - case _ if issubclass(origin, dict): - # arg0=key (role='key'), arg1=value (role='val') - return self._handle_with_triggers( - "process_dict", - ctx, - args_handler=lambda c: ( - self._walk_child(c.args[0], c, arg_index=0, role="key", token=None), - self._walk_child(c.args[1], c, arg_index=1, role="val", token=None), - ), - ) - - case _ if issubclass(origin, tuple): - return self._handle_with_triggers("process_tuple", ctx, self._walk_args_tuple) - - case _ if issubclass(origin, list): - # single child T with role='elem' - return self._handle_with_triggers( - "process_list", - ctx, - args_handler=lambda c: ( - self._walk_child(c.args[0], c, arg_index=0, role="elem", token=None) - if c.args - else None - ), - ) - - case _ if issubclass(origin, set): - # single child T with role='elem' - return self._handle_with_triggers( - "process_set", - ctx, - args_handler=lambda c: ( - self._walk_child(c.args[0], c, arg_index=0, role="elem", token=None) - if c.args - else None - ), - ) - - case _ if origin in self._allowed_types: - return self._handle_with_triggers("process_allowed", ctx) - - case _: - res = self._apply_triggers("process_unknown", ctx) - if res.restart_with is not None: - return self._walk(res.restart_with, ctx.parent) - if res.replace_with is not None: - return res.replace_with - raise UnsupportedFieldType( - f"Not supported Field: {ctx.origin}, Supported list: {self._allowed_types}" - ) - - def _walk_child( - self, - type_expr: Any, - parent_ctx: AnnotationWalkerCtx, - *, - arg_index: int, - role: str | None, - token: Any | None, - ) -> Any: - origin = get_origin(type_expr) or type_expr - if origin is None: - origin = NoneType - if origin is Union: - origin = UnionType - if not isinstance(origin, type): - raise RuntimeError("Annotation must be using type(s), not instances") - - args = get_args(type_expr) - child = self._new_ctx(origin, args, parent_ctx.layer + 1, parent_ctx) - # stamp routing metadata for triggers - child._set_edge(role=role, token=token, arg_index=arg_index) - - # IMPORTANT: recurse with the CHILD as the parent_ctx for the next step, - # so the walker uses this child ctx (with edge metadata & incremented layer). - return self._walk(type_expr, child)