From 62fedab55206eb81843b7d9bb92e711329ae61d8 Mon Sep 17 00:00:00 2001 From: cclecle Date: Tue, 31 Oct 2023 23:58:51 +0000 Subject: [PATCH 01/20] first commit --- .project | 2 +- .settings/org.eclipse.core.resources.prefs | 1 + README.md | 53 +- pyproject.toml | 9 +- src/pyrestresource/__init__.py | 59 +- src/pyrestresource/__metadata__.py | 42 ++ src/pyrestresource/data/.keep | 0 src/pyrestresource/data/__init__.py | 7 - src/pyrestresource/helpers.py | 17 + src/pyrestresource/rest_request.py | 212 ++++++ src/pyrestresource/rest_request_opt.py | 73 ++ src/pyrestresource/rest_resource.py | 281 ++++++++ src/pyrestresource/rest_resource_handler.py | 631 ++++++++++++++++++ .../rest_resource_handler_walker.py | 84 +++ src/pyrestresource/rest_resource_plugin.py | 170 +++++ src/pyrestresource/rest_resource_walker.py | 291 ++++++++ src/pyrestresource/rest_types.py | 105 +++ src/pyrestresource/test_module.py | 43 -- test/__init__.py | 2 +- test/test_rest_login.py | 80 +++ test/test_rest_resource.py | 557 ++++++++++++++++ test/test_rest_resource_plugins.py | 216 ++++++ test/test_rest_resource_walker.py | 163 +++++ test/test_rest_resource_walker_tree.py | 150 +++++ test/test_rest_webserver.py | 382 +++++++++++ test/test_test_module.py | 35 - 26 files changed, 3517 insertions(+), 148 deletions(-) create mode 100644 src/pyrestresource/__metadata__.py delete mode 100644 src/pyrestresource/data/.keep delete mode 100644 src/pyrestresource/data/__init__.py create mode 100644 src/pyrestresource/helpers.py create mode 100644 src/pyrestresource/rest_request.py create mode 100644 src/pyrestresource/rest_request_opt.py create mode 100644 src/pyrestresource/rest_resource.py create mode 100644 src/pyrestresource/rest_resource_handler.py create mode 100644 src/pyrestresource/rest_resource_handler_walker.py create mode 100644 src/pyrestresource/rest_resource_plugin.py create mode 100644 src/pyrestresource/rest_resource_walker.py create mode 100644 src/pyrestresource/rest_types.py delete mode 100644 src/pyrestresource/test_module.py create mode 100644 test/test_rest_login.py create mode 100644 test/test_rest_resource.py create mode 100644 test/test_rest_resource_plugins.py create mode 100644 test/test_rest_resource_walker.py create mode 100644 test/test_rest_resource_walker_tree.py create mode 100644 test/test_rest_webserver.py delete mode 100644 test/test_test_module.py diff --git a/.project b/.project index dc36649..f4b48b7 100644 --- a/.project +++ b/.project @@ -1,6 +1,6 @@ - {{project_name}} + pyrestresource diff --git a/.settings/org.eclipse.core.resources.prefs b/.settings/org.eclipse.core.resources.prefs index 99f26c0..f0456ba 100644 --- a/.settings/org.eclipse.core.resources.prefs +++ b/.settings/org.eclipse.core.resources.prefs @@ -1,2 +1,3 @@ eclipse.preferences.version=1 +encoding//src/pyrestresource/__init__.py=utf-8 encoding/=UTF-8 diff --git a/README.md b/README.md index 82c6bd2..36ee062 100644 --- a/README.md +++ b/README.md @@ -8,46 +8,25 @@ ![](docs-static/Library.jpg) -# Python project template +# pyrestresource -A nice template to start blank python projets. - -This template automate a lot of handy things and allow CI/CD automatic releases generation. +A RESTful API library built on top of pydantic & uvicorn to make service API from a data model. -It is also collectings data to feed Jenkins build. +/!\ early in-progress project for internal use ATM. -Checkout [Latest Documentation](https://chacha.ddns.net/mkdocs-web/chacha/{{repository}}/{{branch}}/latest/). +Feel free to contribute. -## Features +Features: +- use annotation +- support containers (dict) +- support plugins (for hook and biding) +- user authentification (WIP) +- ACL (WIP) +- python internal model instance (with possible serialization/auto-save on-disk) +- daemon mode -### Generic pipeline skeleton: - - Prepare - - GetCode - - BuildPackage - - Install - - CheckCode - - PlotMetrics - - RunUnitTests - - GenDOC - - PostRelease - -### CI/CD Environment - - Jenkins - - Gitea (with patch for dynamic Readme variables: https://chacha.ddns.net/gitea/chacha/GiteaMarkupVariable) - - Docker - - MkDocsWeb - -### CI/CD Helper libs - - VirtualEnv - - Changelog generation based on commits - - copier - - pylint + pylint_json2html - - mypy - - unittest + xmlrunner + junitparser + junit2htmlreport - - mkdocs +Limitations: +- no nested reads / writes +- weak unitest (atm) -### Python project - - Full .toml implementation - - .whl automatic generation - - dynamic versionning using git repository - - embedded unit-test \ No newline at end of file +Checkout [Latest Documentation](https://chacha.ddns.net/mkdocs-web/chacha/pyrestresource/master/latest/). \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cfea0e9..04f614b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ version_scheme= "post-release" name = "pyrestresource" description = "pyrestresource" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.11" keywords = ["chacha","chacha","template","pyrestresource"] license = { file = "LICENSE.md" } @@ -30,11 +30,12 @@ maintainers = [ classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.11", ] dependencies = [ - 'importlib-metadata; python_version<"3.9"', - 'packaging' + 'packaging', + 'pydantic>=2.4,<3', + 'uvicorn>=0.23' ] dynamic = ["version"] diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index 7805b58..6ee712e 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + # pyrestresource (c) by chacha # # pyrestresource is licensed under a @@ -5,32 +8,48 @@ # # You should have received a copy of the license along with this # work. If not, see . +# pylint: disable=wrong-import-position """ Main module __init__ file. """ -from importlib.metadata import version, distribution, PackageNotFoundError -import warnings +from typing import TYPE_CHECKING -from .test_module import test_function +from .__metadata__ import __version__, __Summuary__, __Name__ -try: # pragma: no cover - __version__ = version("pyrestresource") -except PackageNotFoundError: # pragma: no cover - warnings.warn("can not read __version__, assuming local test context, setting it to ?.?.?") - __version__ = "?.?.?" -try: # pragma: no cover - dist = distribution("pyrestresource") - __Summuary__ = dist.metadata["Summary"] -except PackageNotFoundError: # pragma: no cover - warnings.warn('can not read dist.metadata["Summary"], assuming local test context, setting it to ') - __Summuary__ = "pyrestresource description" +from .rest_resource import ( + register_rest_rootpoint, + RestResourceBase, +) -try: # pragma: no cover - dist = distribution("pyrestresource") - __Name__ = dist.metadata["Name"] -except PackageNotFoundError: # pragma: no cover - warnings.warn('can not read dist.metadata["Name"], assuming local test context, setting it to ') - __Name__ = "pyrestresource" +from .rest_types import rsrc_verb, T_SupportedRESTFields + +if TYPE_CHECKING: + from .rest_types import ( + T_ListIndex, + T_ListSize, + T_DictKey, + T_T_DictKey, + T_DictValues, + T_T_DictValues, + ) + +from .rest_request_opt import ( + RestRequestParams_POST, + RestRequestParams_DELETE, + RestRequestParams_GET, + RestRequestParams_PUT, + RestRequestParams_RestResourceBase_PUT, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, +) + +from .rest_resource_plugin import ( + ResourcePlugin_field_default, + ResourcePlugin_RestResourceBase_default, + ResourcePlugin_dict_default, +) diff --git a/src/pyrestresource/__metadata__.py b/src/pyrestresource/__metadata__.py new file mode 100644 index 0000000..cfb1763 --- /dev/null +++ b/src/pyrestresource/__metadata__.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# pyrestresource(c) by chacha +# +# pyrestresource is licensed under a +# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. +# +# You should have received a copy of the license along with this +# work. If not, see . + +"""Metadata module module""" + +from importlib.metadata import version, distribution, PackageNotFoundError +import warnings + + +try: # pragma: no cover + __version__ = version("pyrestresource") +except PackageNotFoundError: # pragma: no cover + warnings.warn( + "can not read __version__, assuming local test context, setting it to ?.?.?" + ) + __version__ = "?.?.?" + +try: # pragma: no cover + dist = distribution("pyrestresource") + __Summuary__ = dist.metadata["Summary"] +except PackageNotFoundError: # pragma: no cover + warnings.warn( + 'can not read dist.metadata["Summary"], assuming local test context, setting it to ' + ) + __Summuary__ = "pyrestresource description" + +try: # pragma: no cover + dist = distribution("pyrestresource") + __Name__ = dist.metadata["Name"] +except PackageNotFoundError: # pragma: no cover + warnings.warn( + 'can not read dist.metadata["Name"], assuming local test context, setting it to ' + ) + __Name__ = "pyrestresource" diff --git a/src/pyrestresource/data/.keep b/src/pyrestresource/data/.keep deleted file mode 100644 index e69de29..0000000 diff --git a/src/pyrestresource/data/__init__.py b/src/pyrestresource/data/__init__.py deleted file mode 100644 index 0aef653..0000000 --- a/src/pyrestresource/data/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# pyrestresource (c) by chacha -# -# pyrestresource is licensed under a -# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. -# -# You should have received a copy of the license along with this -# work. If not, see . diff --git a/src/pyrestresource/helpers.py b/src/pyrestresource/helpers.py new file mode 100644 index 0000000..a39262c --- /dev/null +++ b/src/pyrestresource/helpers.py @@ -0,0 +1,17 @@ +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring + +from __future__ import annotations +from uuid import UUID +import json + +from .rest_types import T_Gen_DictKeys + + +class _JSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, T_Gen_DictKeys): # pylint: disable=isinstance-second-argument-not-valid-type + return list(o) + if isinstance(o, UUID): + # if the obj is uuid, we simply return the value of uuid + return str(o) + return json.JSONEncoder.default(self, o) diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py new file mode 100644 index 0000000..951456d --- /dev/null +++ b/src/pyrestresource/rest_request.py @@ -0,0 +1,212 @@ +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring +"""A module to handle http requests context""" + +from __future__ import annotations +from typing import ( + Optional, + Generic, +) +from re import sub +from urllib.parse import urlparse, parse_qs +from pydantic import BaseModel, Field + + +from .rest_types import rsrc_verb, T_SupportedRESTFields + +from .rest_request_opt import ( + RestRequestParams_POST, + RestRequestParams_DELETE, + RestRequestParams_GET, + RestRequestParams_PUT, + _T_RestRequestParams, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, +) + + +class RequestFactory( + Generic[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ], + BaseModel, +): + """RestRequets class factory""" + + cls_RestRequestParams_GET: type[RestRequestParams_GET] = Field(default=RestRequestParams_GET) + cls_RestRequestParams_PUT: type[RestRequestParams_PUT] = Field(default=RestRequestParams_PUT) + cls_RestRequestParams_POST: type[RestRequestParams_POST] = Field(default=RestRequestParams_POST) + cls_RestRequestParams_DELETE: type[RestRequestParams_DELETE] = Field(default=RestRequestParams_DELETE) + + def get_RestRequest(self, url: str, verb: rsrc_verb, data: dict, query_string: Optional[str] = None) -> RestRequest: + """get a RestRequets instance based on LUT_verb configuration + + Args: + url: http url of the request + verb: http verb received + data: data associated with the request + """ + + # /!\ mypy seems not being able to propagate typevar to composed classes + if verb is rsrc_verb.GET: + return RestRequest[RestRequestParams_GET](self.cls_RestRequestParams_GET, url, verb, data, query_string) + if verb is rsrc_verb.PUT: + return RestRequest[RestRequestParams_PUT](self.cls_RestRequestParams_PUT, url, verb, data, query_string) + if verb is rsrc_verb.POST: + return RestRequest[RestRequestParams_POST](self.cls_RestRequestParams_POST, url, verb, data, query_string) + if verb is rsrc_verb.DELETE: + return RestRequest[RestRequestParams_DELETE](self.cls_RestRequestParams_DELETE, url, verb, data, query_string) + raise RuntimeError("Invalid Verb") + + def update_RestRequest(self, origin_request: RestRequest) -> RestRequest: + """create an updated copy of a RestRequest object based on a different LUT_verb configuration + Args: + origin_request: the original request + """ + + # /!\ mypy seems not being able to propagate typevar to composed classes + if origin_request.verb is rsrc_verb.GET: + return RestRequest[RestRequestParams_GET](self.cls_RestRequestParams_GET, None, None, None, None, origin_request) + if origin_request.verb is rsrc_verb.PUT: + return RestRequest[RestRequestParams_PUT](self.cls_RestRequestParams_PUT, None, None, None, None, origin_request) + if origin_request.verb is rsrc_verb.POST: + return RestRequest[RestRequestParams_POST](self.cls_RestRequestParams_POST, None, None, None, None, origin_request) + if origin_request.verb is rsrc_verb.DELETE: + return RestRequest[RestRequestParams_DELETE](self.cls_RestRequestParams_DELETE, None, None, None, None, origin_request) + raise RuntimeError("Invalid Verb") + + +class RestRequest(Generic[_T_RestRequestParams]): + # pylint: disable=too-many-instance-attributes + """Main RestRequets class""" + + def __init__( + self, + type_request_params: type[_T_RestRequestParams], + url: Optional[str] = None, + verb: Optional[rsrc_verb] = None, + data: Optional[dict[str, T_SupportedRESTFields]] = None, + query_string: Optional[str] = None, + origin_request: Optional[RestRequest] = None, + ) -> None: + """class to handle a request context, that will be kept and updated while walking url parts + + Args: + url: http url of the request + verb: http verb received + data: data associated with the request + type_request_params: type of the request param + origin_request: orginial request in case of updates. + In this case, all other argument - but type_request_params - are ignored and inherited from the origin_request + """ + + # defining all types + self.url: str + self.verb: rsrc_verb + self.data: dict + self._saved_url_params: dict + self.ReqParams: _T_RestRequestParams = type_request_params() + self.url_stack: list[str] + self._saved_url_stack: list[str] + self.url_stack_index: int + + # detecting Optional fields in type_request_params (to extract real type) + # if False: # deprecated => Generic + # if get_origin(type_request_params) is Union: + # datatype = get_args(type_request_params) + # if len(datatype) == 2: + # if datatype[0] is type(None): + # type_request_params = datatype[1] + # elif datatype[1] is type(None): + # type_request_params = datatype[0] + # else: + # raise RuntimeError("Union is only allowed to describe Optional (e.g. Union[XXX,None])") + + # = updating request from a previous one = + if origin_request: + self.__dict__ = origin_request.__dict__.copy() + if type_request_params: + self.ReqParams = type_request_params(**self._saved_url_params) + # print("request updated") + return + + # = or create a fresh one = + if url is None or verb is None or data is None: + raise RuntimeError("url and verb and data must be set") + self.url = url + self.verb = verb + self.data = data + + # parse_qs returns list[] for all keys, the command convert list to single items so pydantic can eat them :) + if query_string: + self._saved_url_params = dict((k, v if len(v) > 1 else v[0]) for k, v in parse_qs(query_string).items()) + else: + self._saved_url_params = dict((k, v if len(v) > 1 else v[0]) for k, v in parse_qs(urlparse(url).query).items()) + + if type_request_params: + self.ReqParams = type_request_params(**self._saved_url_params) # actual lunch + + self._parse_url(url) + + # keeping a backup of the original url stack + self._saved_url_stack = self.url_stack.copy() + self.url_stack_index = 0 + + def _parse_url(self, url: str) -> None: + # remove repeated slash ('/') + url = sub(r"\/{2,}", "/", url) + # root url need to be added manually because it is trimmed by the url split function + self.url_stack = ["/"] + self.url_stack.extend([_ for _ in urlparse(url).path.split("/") if _ != ""]) + + def reset_url_stack(self) -> None: + self.url_stack = self._saved_url_stack.copy() + self.url_stack_index = 0 + + def get_url_stack(self) -> list[str]: + """retrieve the current url stack""" + + return self.url_stack + + def get_url(self) -> str: + """retrieve the raw url""" + return self.url + + def consume_url_stack(self, count: int) -> list[str]: + """consume some url stack elements + + Args: + count: number of element to consume + """ + + returned_stack: list[str] = [] + + for _ in range(count): + returned_stack.append(self.url_stack.pop(0)) + self.url_stack_index = self.url_stack_index + count + return returned_stack + + def get_data(self) -> dict: + """get the request data""" + return self.data + + def get_resource_origin(self, deepness: int = 0) -> str: + """get current or previous (consumed) resource in the url + + Args: + deepness: backward amount + """ + + return self._saved_url_stack[self.url_stack_index - deepness] + + def get_req_params(self) -> _T_RestRequestParams: + """get extracted req_params""" + return self.ReqParams + + def get_verb(self) -> rsrc_verb: + """get http request verb""" + return self.verb diff --git a/src/pyrestresource/rest_request_opt.py b/src/pyrestresource/rest_request_opt.py new file mode 100644 index 0000000..10cc360 --- /dev/null +++ b/src/pyrestresource/rest_request_opt.py @@ -0,0 +1,73 @@ +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring + +from __future__ import annotations + +from typing import Optional, Generic, TypeVar +from pydantic import BaseModel, Extra + +from .rest_types import ( + _T_DictKey, +) + + +class RestRequestParams(BaseModel, extra=Extra.allow): + pass + + +class RestRequestParams_POST(RestRequestParams): + pass + + +class RestRequestParams_DELETE(RestRequestParams): + pass + + +class RestRequestParams_GET(RestRequestParams): + pass + + +class RestRequestParams_PUT(RestRequestParams): + pass + + +class RestRequestParams_RestResourceBase(RestRequestParams): + pass + + +class RestRequestParams_RestResourceBase_PUT(RestRequestParams_PUT, RestRequestParams_RestResourceBase): + pass + + +class RestRequestParams_RestResourceBase_GET(RestRequestParams_GET, RestRequestParams_RestResourceBase): + pass + + +class RestRequestParams_Dict(RestRequestParams): + pass + + +class RestRequestParams_Dict_POST(RestRequestParams_Dict, RestRequestParams_POST, Generic[_T_DictKey]): + API_key: Optional[_T_DictKey] = None + + +class RestRequestParams_Dict_DELETE(RestRequestParams_Dict, RestRequestParams_DELETE, Generic[_T_DictKey]): + API_key: Optional[_T_DictKey] = None + + +class RestRequestParams_Dict_GET(RestRequestParams_Dict, RestRequestParams_GET): + pass + + +class RestRequestParams_Dict_elem_GET(RestRequestParams_Dict, RestRequestParams_GET): + pass + + +class RestRequestParams_Dict_elem_PUT(RestRequestParams_Dict, RestRequestParams_GET): + pass + + +_T_RestRequestParams = TypeVar("_T_RestRequestParams", bound=RestRequestParams) +_T_RestRequestParams_POST = TypeVar("_T_RestRequestParams_POST", bound=RestRequestParams_POST) +_T_RestRequestParams_DELETE = TypeVar("_T_RestRequestParams_DELETE", bound=RestRequestParams_DELETE) +_T_RestRequestParams_GET = TypeVar("_T_RestRequestParams_GET", bound=RestRequestParams_GET) +_T_RestRequestParams_PUT = TypeVar("_T_RestRequestParams_PUT", bound=RestRequestParams_PUT) diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py new file mode 100644 index 0000000..987d8d2 --- /dev/null +++ b/src/pyrestresource/rest_resource.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# pyrestresource(c) by chacha +# +# pyrestresource is licensed under a +# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. +# +# You should have received a copy of the license along with this +# work. If not, see . + +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring + +"""CLI interface module""" +from __future__ import annotations + +from abc import ABC +from typing import ( + cast, + ClassVar, + get_args, + get_origin, + Optional, + TYPE_CHECKING, +) +import json +from pydantic.fields import FieldInfo +from pydantic import BaseModel + +from .helpers import _JSONEncoder +from .rest_types import rsrc_verb, _T_SupportedRESTFields +from .rest_resource_plugin import ( + ResourcePlugin_field, + ResourcePlugin_RestResourceBase, + ResourcePlugin_dict, +) + + +from .rest_resource_walker import ( + RestResourceWalkerFutureResult, + RestResourceWalker_Root, + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, +) + +if TYPE_CHECKING: + from .rest_types import ( + T_ListIndex, + T_ListSize, + T_DictKey, + T_T_DictKey, + T_DictValues, + T_T_DictValues, + T_SupportedRESTFields, + ) + + +class RestResourceWalkerFutureResult_RestResourceBase_tree_exclude(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + res = {} + res[self.source.resource_name] = dict() + for subres in result: + key = next(iter(subres)) + if ( + key in self.source.annotation._model_dump_excluded_ # pylint: disable=protected-access + and self.source.annotation._model_dump_excluded_[key] is True # pylint: disable=protected-access + ): + res[self.source.resource_name] = res[self.source.resource_name] | {key: True} + else: + res[self.source.resource_name] = res[self.source.resource_name] | subres + return res + + +class RestResourceWalkerFutureResult_Dict_tree_exclude(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + res = {} + for subres in result: + res = res | subres + return res + + +class RestResourceWalker_Sub_T_Dict__tree_exclude(RestResourceWalker_Sub_T_Dict): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_Dict_tree_exclude + + +class RestResourceWalker_Sub_RestResourceBase__tree_exclude(RestResourceWalker_Sub_RestResourceBase): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestResourceBase_tree_exclude + + +class RestResourceWalker_Root__tree_exclude(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict__tree_exclude, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase__tree_exclude, + ] + + +class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): + def process(self) -> None: + datatype = get_args(self.annotation) + + # checking compatibility + if not get_origin(datatype[1]) is None: + raise RuntimeError("complex dict types are not supported (should create a RestResourceBase container)") + if not datatype[0] in _T_SupportedRESTFields: + raise RuntimeError(f"Unsupported Dict Field value type in class (key)") + + # preprocessing types / structure + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + self.parent.annotation._dict_key_type_[self.resource_name] = datatype[0] # pylint: disable=protected-access + self.parent.annotation._dict_value_type_[self.resource_name] = datatype[1] # pylint: disable=protected-access + self.parent.annotation._model_dump_excluded_[self.resource_name] = True # pylint: disable=protected-access + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + and "plugin" in self.resource.json_schema_extra + ): + plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"]() + if not isinstance(plugin_dict, ResourcePlugin_dict): + raise RuntimeError("Wrong plugin signature provided") + self.parent.annotation._plugins_[self.resource_name] = plugin_dict + # print("ADD DICT PLUGIN") + + else: + raise RuntimeError("dict must be contained in a RestResourceBase") + + +class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFields): + def process(self) -> None: + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + ): + if "primary_key" in self.resource.json_schema_extra and self.resource.json_schema_extra["primary_key"] is True: + if self.parent.annotation._primary_key_ is not None: + raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") + self.parent.annotation._primary_key_ = self.resource_name + + if "plugin" in self.resource.json_schema_extra and self.resource.json_schema_extra["plugin"]: + plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"]() + if not isinstance(plugin_field, ResourcePlugin_field): + raise RuntimeError("Wrong plugin signature provided") + self.parent.annotation._plugins_[self.resource_name] = plugin_field + # print("ADD FIELD PLUGIN") + + else: + raise RuntimeError("fields must be contained in a RestResourceBase") + + +class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_RestResourceBase): + def process(self) -> None: + setattr(self.annotation, "_dict_key_type_", {}) + setattr(self.annotation, "_dict_value_type_", {}) + setattr(self.annotation, "_model_dump_excluded_", {}) + setattr(self.annotation, "_primary_key_", None) + setattr(self.annotation, "_plugins_", {}) + + # preprocessing types / structure + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + self.parent.annotation._model_dump_excluded_[self.resource_name] = True + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + and "plugin" in self.resource.json_schema_extra + ): + plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"]() + if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase): + raise RuntimeError("Wrong plugin signature provided") + self.parent.annotation._plugins_[self.resource_name] = plugin_resource + # print("ADD RESOURCE PLUGIN") + + +class RestResourceWalker_Root__tree_init(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict__tree_init, + RestResourceWalker_Sub_RestFields__tree_init, + RestResourceWalker_Sub_RestResourceBase__tree_init, + ] + + +def register_rest_rootpoint(klass: type[RestResourceBase]): + RestResourceWalker_Root__tree_init(klass).process() + return klass + + +class RestResourceBase(ABC, BaseModel, validate_assignment=True): + _dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {} + _dict_value_type_: ClassVar[dict[str, T_T_DictValues]] = {} + _model_dump_excluded_: ClassVar[dict[str, bool]] = {} + _primary_key_: ClassVar[Optional[str]] = None + _plugins_: ClassVar[ + dict[ + str, + ResourcePlugin_field | ResourcePlugin_RestResourceBase | ResourcePlugin_dict, + ] + ] = {} + + def update(self, **new_data): + for field, value in new_data.items(): + setattr(self, field, value) + + async def read_body(self, receive): + """ + Read and return the entire body from an incoming ASGI message. + """ + body = b"" + more_body = True + + while more_body: + message = await receive() + body += message.get("body", b"") + more_body = message.get("more_body", False) + + return body + + async def __call__(self, scope, receive, send): + assert scope["type"] == "http" + method = scope["method"] + assert method in ["GET", "DELETE", "PUT", "POST"] + if b"content-type" in scope["headers"]: + assert scope["headers"][b"content-type"] == b"application/json" + + body = await self.read_body(receive) + verb = rsrc_verb[scope["method"]] + result = self.process_request( + scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8") + ) + + status = 200 + if verb in (rsrc_verb.POST, rsrc_verb.PUT): + status = 201 + + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + [b"content-type", b"application/json"], + ], + } + ) + body = None + if result: + body = result.encode("utf-8") + await send( + { + "type": "http.response.body", + "body": body, + } + ) + + def process_request( + self, url: str, verb: rsrc_verb = rsrc_verb.GET, data_json: Optional[str] = None, query_string: Optional[str] = None + ) -> Optional[str]: + from .rest_resource_handler import ( + ResourceHandler, + ResourceHandler_RestResourceBase, + ) + + data: dict = {} + if data_json: + data = json.loads(data_json) + + ressource: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) + result = ressource.process_verb() + + if isinstance(result, RestResourceBase): + exclude: Optional[dict[str, bool]] = None + raw_exclude = RestResourceWalker_Root__tree_exclude(result).process() + exclude = next(iter(raw_exclude.values())) + return json.dumps(result.model_dump(mode="json", exclude=exclude)) + + if result is not None: + return json.dumps(result, cls=_JSONEncoder) + return None diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py new file mode 100644 index 0000000..56d304d --- /dev/null +++ b/src/pyrestresource/rest_resource_handler.py @@ -0,0 +1,631 @@ +from __future__ import annotations +import abc +from typing import Optional, cast, TypeVar, Generic, Self, TYPE_CHECKING + +from .rest_types import ( + rsrc_verb, + T_SupportedRESTFields, + T_DictKey, + _T_SupportedRESTFields, + T_Dict, + T_T_DictValues, + T_DictValues, +) +from .rest_resource import RestResourceBase +from .rest_request import RequestFactory, RestRequest + +from .rest_resource_plugin import ( + ResourcePlugin_field, + ResourcePlugin_RestResourceBase, +) + +from .rest_request_opt import ( + RestRequestParams_POST, + RestRequestParams_DELETE, + RestRequestParams_GET, + RestRequestParams_PUT, + RestRequestParams_RestResourceBase_PUT, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, +) + +from .rest_resource_handler_walker import RestResourceWalker_Root__handler + +if TYPE_CHECKING: + from .rest_types import ( + T_ListIndex, + T_ListSize, + T_T_DictKey, + T_FieldValue, + ) + + +_T_Resource = TypeVar("_T_Resource", T_DictValues, T_Dict, T_SupportedRESTFields, RestResourceBase) + + +class ResourceHandler( + abc.ABC, + Generic[ + _T_Resource, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ], +): + _ar_resource_handler_cls_: list[type[ResourceHandler]] = [] + _nb_url_element_to_consume_ = 1 + + _request_factory: RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ] = RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ]( + cls_RestRequestParams_GET=RestRequestParams_GET, + cls_RestRequestParams_PUT=RestRequestParams_PUT, + cls_RestRequestParams_POST=RestRequestParams_POST, + cls_RestRequestParams_DELETE=RestRequestParams_DELETE, + ) + + def __init__( + self, + resource: _T_Resource, + url: Optional[str] = None, + verb: Optional[rsrc_verb] = None, + data: Optional[dict] = None, + query_string: Optional[str] = None, + prev_handler: Optional[ResourceHandler] = None, + ) -> None: + self.prev_handler: Optional[ResourceHandler] = None + self.next_handler: Optional[ResourceHandler] = None + self.saved_url: list[str] = [] + self.resource: _T_Resource = resource + self.req: RestRequest + if prev_handler is not None: + self.prev_handler = prev_handler + self.req = self._request_factory.update_RestRequest(self.prev_handler.req) + elif None in [url, verb]: + raise RuntimeError("if req not set, url,verb must be setted") + else: + if url is None or verb is None: + raise RuntimeError("url and verb must be set") + if data is None: + data = {} + self.req = self._request_factory.get_RestRequest(url, verb, data, query_string) + + # print(f"[TRACE] creating {type(self).__name__}() with url={self.req.get_url_stack()}") + + @classmethod + def create_chained_handler(cls, other: ResourceHandler, resource: _T_Resource) -> Self: + return cls(resource, None, None, None, None, other) + + @classmethod + @abc.abstractmethod + def _check_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> bool: + return False + + @classmethod + def _get_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> type[ResourceHandler]: + for resource_handler_cls in cls._ar_resource_handler_cls_: + if resource_handler_cls._check_resource_handler(resource, req): + # print(f"[DEBUG] match ResourceHandler: {resource_handler_cls.__name__}") + return resource_handler_cls + raise RuntimeError(f"Unsupported Resource Type {type(resource).__name__}") + + @classmethod + def register_resource_handler(cls, other_cls) -> None: + cls._ar_resource_handler_cls_.append(other_cls) + return other_cls + + def process_verb( + self, + ) -> Optional[_T_Resource | T_DictKey | list[T_DictKey]]: + # print(f"[TRACE] {type(self).__name__}->process_verb()") + self._reset_context() + resource_handler = self._find_resource() + return resource_handler._process_verb() + + def access_resource( + self, + ) -> _T_Resource: + # print(f"[TRACE] {type(self).__name__}->access_resource()") + self._reset_context() + resource_handler = self._find_resource() + return resource_handler.resource + + def _reset_context(self) -> None: + self.req.reset_url_stack() + + def _find_resource( + self, + ) -> ResourceHandler[ + _T_Resource, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ]: + # print(f"[TRACE] {type(self).__name__}->_find_resource()") + # print(f"[DEBUG] {type(self).__name__}->resource = {type(self.resource).__name__}") + + if len(self.req.get_url_stack()) == 0: + return self + + self._check_access_rights() + + next_resource = self._process_get() + # reveal_type(next_resource) + _next_resource = cast(_T_Resource, next_resource) + # reveal_type(_next_resource) + # print(f"[DEBUG] next_resource = {type(next_resource).__name__}") + + if ( + isinstance(_next_resource, RestResourceBase) + or isinstance(_next_resource, dict) + or type(_next_resource) in _T_SupportedRESTFields + ): + next_resource_handler_cls: type[ResourceHandler] = self._get_resource_handler(_next_resource, self.req) + + self.saved_url = self.req.consume_url_stack(self._nb_url_element_to_consume_) + + # we always create a new ResourceHandler context because we might want to access + # previous saved/chained ones (for exemple to put/post/delete values in containers) + # if next_resource_handler_cls != type(self): + # print(f"[DEBUG] CHANGING HANDLER to {next_resource_handler_cls.__name__}") + + next_resource_handler: ResourceHandler = next_resource_handler_cls.create_chained_handler(self, _next_resource)._find_resource() + self.next_handler = next_resource_handler + return next_resource_handler + + # in the context of _find_resource, only resource real values can be retrieved + raise RuntimeError("Wrong request") + + def _check_access_rights(self): + pass + + def _process_verb( + self, + ) -> Optional[_T_Resource | T_DictKey | list[T_DictKey]]: + # print(f"[TRACE] {type(self).__name__}->_process_verb()") + + verb = self.req.get_verb() + + if verb is rsrc_verb.GET: + return self._process_get() + if verb is rsrc_verb.PUT: + self._process_put() + return None + if verb is rsrc_verb.POST: + return self._process_post() + if verb is rsrc_verb.DELETE: + self._process_delete() + return None + + raise RuntimeError("Invalid Verb") + + def _process_get( + self, + ) -> _T_Resource | list[T_DictKey]: + return self._handle_process_get(self.req.get_req_params()) + + def _process_put(self) -> None: + self._handle_process_put(self.req.get_req_params()) + + def _process_post( + self, + ) -> Optional[T_DictKey]: + return self._handle_process_post(self.req.get_req_params()) + + def _process_delete( + self, + ) -> None: + self._handle_process_delete(self.req.get_req_params()) + + def _handle_process_get(self, params: _T_RestRequestParams_GET) -> _T_Resource | list[T_DictKey]: + raise RuntimeError(f"GET method not implemented for {type(self).__name__}") + + def _handle_process_put(self, params: _T_RestRequestParams_PUT) -> None: + raise RuntimeError(f"PUT method not implemented for {type(self).__name__}") + + def _handle_process_post(self, params: _T_RestRequestParams_POST) -> Optional[T_DictKey]: + raise RuntimeError(f"POST method not implemented for {type(self).__name__}") + + def _handle_process_delete(self, params: _T_RestRequestParams_DELETE) -> None: + raise RuntimeError(f"DELETE method not implemented for {type(self).__name__}") + + +@ResourceHandler.register_resource_handler +class ResourceHandler_dict( + ResourceHandler[ + T_Dict, + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, + _T_RestRequestParams_PUT, + ] +): + _nb_url_element_to_consume_ = 0 + + _request_factory: RequestFactory[ + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, + _T_RestRequestParams_PUT, + ] = RequestFactory[ + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, + _T_RestRequestParams_PUT, + ]( + cls_RestRequestParams_GET=RestRequestParams_Dict_GET, + cls_RestRequestParams_POST=RestRequestParams_Dict_POST, + cls_RestRequestParams_DELETE=RestRequestParams_Dict_DELETE, + ) + + @classmethod + def _check_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> bool: + # print(f"{cls.__name__}->_check_resource_handler()") + + return isinstance(resource, dict) and len(req.get_url_stack()) == 1 + + def _handle_process_get(self, params) -> list[T_DictKey]: + # print(f"{type(self).__name__}->_process_get()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + _dict: dict[T_DictKey, T_DictValues] = cast(dict[T_DictKey, T_DictValues], self.resource) + + return list(_dict.keys()) + + def _handle_process_delete(self, params) -> None: + # print(f"{type(self).__name__}->_handle_process_delete()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + if self.prev_handler is None: + raise RuntimeError("Wrong command") + + dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] + + _dict: dict[T_DictKey, "T_DictValues"] = cast(dict[T_DictKey, "T_DictValues"], self.resource) + + if params.API_key is not None: + del _dict[dict_key_type(params.API_key)] + else: + _dict.clear() + return + + def _handle_process_post(self, params) -> Optional[T_DictKey]: + # pylint: disable=protected-access + + # print(f"{type(self).__name__}->_handle_process_post()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + if self.prev_handler is None: + raise RuntimeError("Wrong command") + + dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] + dict_value_type: T_T_DictValues = cast(RestResourceBase, self.prev_handler.resource)._dict_value_type_[ + self.req.get_resource_origin(1) + ] + + _obj = dict_value_type(**self.req.get_data()) + + _dict: dict[T_DictKey, "T_DictValues"] = cast(dict[T_DictKey, "T_DictValues"], self.resource) + + # 1st try/ using request param provided dict API_key + if params.API_key is not None: + # if a primary key is set for the resource, updating it + if isinstance(_obj, RestResourceBase): + if _obj._primary_key_ is not None: + _pri: T_DictKey = dict_key_type(params.API_key) + setattr(_obj, _obj._primary_key_, _pri) + # storing resource + _dict[dict_key_type(params.API_key)] = _obj + return dict_key_type(params.API_key) + + # 2nd try/ using provided resource internal primary key + # & 3rd try/ using resource internal auto-generated primary key + # => this case is automatic because if self.req.get_data() doesn't contain the key, it should be automatically created + if isinstance(_obj, RestResourceBase): + if _obj._primary_key_ is not None: + _obj_primary_key: Optional[T_DictKey] = getattr(_obj, _obj._primary_key_) + if _obj_primary_key is not None: + _dict[_obj_primary_key] = _obj + return _obj_primary_key + + RuntimeError("Either the object needs defined primary key or the request must contain an API_key param to process this command") + return None # for mypy.... + + +@ResourceHandler.register_resource_handler +class ResourceHandler_dict_elem( + ResourceHandler[ + T_DictValues, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + _T_RestRequestParams_PUT, + ] +): + _nb_url_element_to_consume_ = 1 + + _request_factory: RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + _T_RestRequestParams_PUT, + ] = RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + _T_RestRequestParams_PUT, + ]( + cls_RestRequestParams_GET=RestRequestParams_RestResourceBase_GET + ) + + @classmethod + def _check_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> bool: + # print(f"{cls.__name__}->_check_resource_handler()") + + return isinstance(resource, dict) and len(req.get_url_stack()) > 1 + + def _handle_process_get(self, params) -> T_DictValues: + # print(f"{type(self).__name__}->_process_get()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + if self.prev_handler is None: + raise RuntimeError("Wrong command") + + dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] + + if issubclass(dict_key_type, bytes): + key_byte = dict_key_type(self.req.get_resource_origin(0), "utf-8") + return cast(dict[T_DictKey, T_DictValues], self.resource)[key_byte] + else: + key = dict_key_type(self.req.get_resource_origin(0)) + return cast(dict[T_DictKey, T_DictValues], self.resource)[key] + + def _handle_process_delete(self, params) -> None: + # print(f"{type(self).__name__}->_handle_process_delete()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + # this is a litle bit tricky because this call comes from a next resource, so get_resource_origin(2) + # instead of expected get_resource_origin(1) because we need to go backward + # because self.req is another context that is not saved to improve performances + + if self.prev_handler is None: + raise RuntimeError("Wrong command") + + dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(2)] + + if issubclass(dict_key_type, bytes): + key_byte = dict_key_type(self.req.get_resource_origin(1), "utf-8") + del cast(dict[T_DictKey, T_DictValues], self.resource)[key_byte] + else: + key = dict_key_type(self.req.get_resource_origin(1)) + del cast(dict[T_DictKey, T_DictValues], self.resource)[key] + + +@ResourceHandler.register_resource_handler +class ResourceHandler_RestResourceBase( + ResourceHandler[ + RestResourceBase, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_RestResourceBase_PUT, + ] +): + _request_factory: RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_RestResourceBase_PUT, + ] = RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_RestResourceBase_PUT, + ]( + cls_RestRequestParams_GET=RestRequestParams_RestResourceBase_GET, + cls_RestRequestParams_PUT=RestRequestParams_RestResourceBase_PUT, + ) + + @classmethod + def _check_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> bool: + # print(f"{cls.__name__}->_check_resource_handler()") + + return isinstance(resource, RestResourceBase) + + def _check_access_rights(self) -> None: + super()._check_access_rights() + # print(f"{type(self).__name__}->_check_access_rights()") + + if self.req.get_resource_origin(0) == "/": + return + + if ( + self.req.get_resource_origin(0) not in self.resource.model_fields + or self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True + ): + raise RuntimeError(f"Unknown or not allowed field access detected: {self.req.get_url_stack()}") + + def _handle_process_get(self, params) -> RestResourceBase: + # print(f"{type(self).__name__}->_process_get()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + # CASE 1: no more item in url_stack => we reached the endpoint + # So we are in a RestResourceBase instance and must return the content + if len(self.req.get_url_stack()) == 0: + for key, attr in self.resource.model_fields.items(): + if key in self.resource._plugins_: + if isinstance(self.resource._plugins_[key], ResourcePlugin_field): + plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, self.resource._plugins_[key]) + value = getattr(self.resource, key) + setattr(self.resource, key, plugin_field.handle_field_get(value, params)) + elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): + plugin_field: ResourcePlugin_field = cast(ResourcePlugin_RestResourceBase, self.resource._plugins_[key]) + value = getattr(self.resource, key) + setattr(self.resource, key, plugin_field.handle_resource_get(value, params)) + + # result = RestResourceWalker_Root__handler(self.resource).process() + # print(result) + return self.resource + + # CASE 2: specific case for root Node + # TODO: this must probably be merged with the previous bloc + if self.req.get_resource_origin(0) == "/": + return self.resource + + # CASE 3: in between + value = getattr(self.resource, self.req.get_resource_origin(0)) + + key = self.req.get_resource_origin(0) + if key in self.resource._plugins_: + if isinstance(self.resource._plugins_[key], ResourcePlugin_field): + plugin_rsrc: ResourcePlugin_RestResourceBase = cast( + ResourcePlugin_RestResourceBase, + self.resource._plugins_[key], + ) + value = plugin_rsrc.handle_field_get(value, params) + + elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): + plugin_rsrc: ResourcePlugin_RestResourceBase = cast( + ResourcePlugin_RestResourceBase, + self.resource._plugins_[key], + ) + value = plugin_rsrc.handle_resource_get(value, params) + + return value + + def _handle_process_put(self, params) -> None: + # print(f"{type(self).__name__}->_process_put()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + # creating a copy of the current resource + _new_resrc = self.resource.copy() + # updating values based on nex data + _new_resrc.update(**self.req.get_data()) + + # applying plugins (to nested element) + if isinstance(_new_resrc, RestResourceBase): + for key, attr in _new_resrc.model_fields.items(): + if key in _new_resrc._plugins_: + if isinstance(_new_resrc._plugins_[key], ResourcePlugin_field): + plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, _new_resrc._plugins_[key]) + value = getattr(_new_resrc, key) + setattr(_new_resrc, key, plugin_field.handle_field_put(value, params)) + + # applying plugins (from parent element) + if self.prev_handler is not None: + # element is within a dict + if ( + isinstance(self.prev_handler.resource, dict) + and self.prev_handler.prev_handler is not None + and isinstance(self.prev_handler.prev_handler.resource, RestResourceBase) + ): + key = self.req.get_resource_origin(2) + if key in self.prev_handler.prev_handler.resource._plugins_: + plugin_rsrc: ResourcePlugin_RestResourceBase = cast( + ResourcePlugin_RestResourceBase, + self.prev_handler.prev_handler.resource._plugins_[key], + ) + _new_resrc = plugin_rsrc.handle_dict_elem_put(_new_resrc, params) + # element is within a RestResourceBase + elif isinstance(self.prev_handler.resource, RestResourceBase): + key = self.req.get_resource_origin(1) + if key in self.prev_handler.resource._plugins_: + plugin_rsrc: ResourcePlugin_RestResourceBase = cast( + ResourcePlugin_RestResourceBase, + self.prev_handler.resource._plugins_[key], + ) + _new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params) + + self.resource.update(**_new_resrc.dict()) + return + + def _handle_process_delete(self, params) -> None: + # print(f"{type(self).__name__}->_handle_process_delete()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + # DELETING an element can only be done from a dict => checking and forwarding + if ( + self.prev_handler is not None + and isinstance(self.prev_handler.resource, dict) + and self.prev_handler.prev_handler is not None + and isinstance(self.prev_handler.prev_handler.resource, RestResourceBase) + ): + self.prev_handler._process_delete() + else: + raise RuntimeError("cannot delete an element outside a dict") + + +@ResourceHandler.register_resource_handler +class ResourceHandler_simple( + ResourceHandler[ + T_SupportedRESTFields, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ] +): + @classmethod + def _check_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> bool: + # print(f"{cls.__name__}->_check_resource_handler()") + + return type(resource) in _T_SupportedRESTFields + + def _handle_process_get(self, params) -> T_SupportedRESTFields: + # print(f"{type(self).__name__}->_process_get()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + assert self.prev_handler is not None + assert isinstance(self.prev_handler.resource, RestResourceBase) + + if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: + plugin_simple: ResourcePlugin_field = cast( + ResourcePlugin_field, + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)], + ) + return plugin_simple.handle_field_get(self.resource, params) + + return self.resource + + def _handle_process_put(self, params) -> None: + # print(f"{type(self).__name__}->_process_put()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + assert self.prev_handler is not None + assert isinstance(self.prev_handler.resource, RestResourceBase) + + value = self.req.get_data() + + if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: + print("PLUGIN FOUND") + plugin_simple: ResourcePlugin_field = cast( + ResourcePlugin_field, + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)], + ) + print(value) + value = plugin_simple.handle_field_put(value, params) + print(value) + + print(self.req.get_resource_origin(1)) + setattr( + self.prev_handler.resource, + self.req.get_resource_origin(1), + value, + ) + print(self.prev_handler.resource) diff --git a/src/pyrestresource/rest_resource_handler_walker.py b/src/pyrestresource/rest_resource_handler_walker.py new file mode 100644 index 0000000..e23ec6f --- /dev/null +++ b/src/pyrestresource/rest_resource_handler_walker.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# pyrestresource(c) by chacha +# +# pyrestresource is licensed under a +# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. +# +# You should have received a copy of the license along with this +# work. If not, see . + +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring + +"""CLI interface module""" +from __future__ import annotations + +from typing import ( + ClassVar, + get_args, + get_origin, + Optional, + TYPE_CHECKING, +) + +from .rest_resource_walker import ( + RestResourceWalkerFutureResult, + RestResourceWalker_Root, + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, +) + + +class RestResourceWalkerFutureResult_RestResourceBase_handler(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + print(f"RestResourceWalkerFutureResult_RestResourceBase_handler {result}") + res = {} + res[self.source.resource_name] = dict() + for subres in result: + key = next(iter(subres)) + print(key) + res[self.source.resource_name] = res[self.source.resource_name] | subres + return res + + +class RestResourceWalkerFutureResult_Dict_handler(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + print(f"RestResourceWalkerFutureResult_Dict_handler {result}") + res = {} + for subres in result: + res = res | subres + return res + + +class RestResourceWalkerFutureResult_RestFields_handler(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + print(f"RestResourceWalkerFutureResult_RestFields_handler {result}") + print(self.source.resource) + res = {} + res[self.source.resource_name] = dict() + for subres in result: + key = next(iter(subres)) + print(key) + res[self.source.resource_name] = res[self.source.resource_name] | subres + return res + + +class RestResourceWalker_Sub_T_Dict__handler(RestResourceWalker_Sub_T_Dict): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_Dict_handler + + +class RestResourceWalker_Sub_RestResourceBase__handler(RestResourceWalker_Sub_RestResourceBase): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestResourceBase_handler + + +class RestResourceWalker_Sub_RestResourceFields__handler(RestResourceWalker_Sub_RestFields): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestFields_handler + + +class RestResourceWalker_Root__handler(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict__handler, + RestResourceWalker_Sub_RestResourceFields__handler, + RestResourceWalker_Sub_RestResourceBase__handler, + ] diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py new file mode 100644 index 0000000..10a5570 --- /dev/null +++ b/src/pyrestresource/rest_resource_plugin.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +from typing import Optional, Protocol, runtime_checkable, TYPE_CHECKING +from abc import abstractmethod + +from .rest_types import ( + _T_DictValues, + _T_DictKey, + TV_SupportedRESTFields, + TV_RestResourceBase, +) + + +if TYPE_CHECKING or True: + from .rest_request_opt import ( + RestRequestParams_GET, + RestRequestParams_PUT, + RestRequestParams_RestResourceBase_PUT, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, + RestRequestParams_Dict_elem_GET, + RestRequestParams_Dict_elem_PUT, + ) + + +@runtime_checkable +class ResourcePlugin_field(Protocol[TV_SupportedRESTFields]): + @abstractmethod + def handle_field_get(self, resource: TV_SupportedRESTFields, params: RestRequestParams_GET) -> TV_SupportedRESTFields: + ... + + @abstractmethod + def handle_field_put(self, resource: TV_SupportedRESTFields, params: RestRequestParams_PUT) -> TV_SupportedRESTFields: + ... + + +class ResourcePlugin_field_default(ResourcePlugin_field[TV_SupportedRESTFields]): + """default implementation of RestResourcePlugin_simple""" + + def handle_field_get(self, resource: TV_SupportedRESTFields, params: RestRequestParams_GET) -> TV_SupportedRESTFields: + return resource + + def handle_field_put(self, resource: TV_SupportedRESTFields, params: RestRequestParams_PUT) -> TV_SupportedRESTFields: + return resource + + +@runtime_checkable +class ResourcePlugin_RestResourceBase(Protocol[TV_RestResourceBase]): + @abstractmethod + def handle_resource_get( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_RestResourceBase_GET, + ) -> TV_RestResourceBase: + ... + + @abstractmethod + def handle_resource_put( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_RestResourceBase_PUT, + ) -> TV_RestResourceBase: + ... + + +class ResourcePlugin_RestResourceBase_default(ResourcePlugin_RestResourceBase[TV_RestResourceBase]): + """default implementation of RestResourcePlugin_RestResourceBase""" + + def handle_resource_get( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_RestResourceBase_GET, + ) -> TV_RestResourceBase: + return resource + + def handle_resource_put( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_RestResourceBase_PUT, + ) -> TV_RestResourceBase: + return resource + + +@runtime_checkable +class ResourcePlugin_dict(Protocol[_T_DictKey, _T_DictValues]): + @abstractmethod + def handle_dict_get_keys( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + params: RestRequestParams_Dict_GET, + ) -> list[_T_DictKey]: + ... + + @abstractmethod + def handle_dict_post( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + resource: _T_DictValues, + params: RestRequestParams_Dict_POST[_T_DictKey], + ) -> Optional[_T_DictKey]: + ... + + @abstractmethod + def handle_dict_delete( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + params: RestRequestParams_Dict_DELETE[_T_DictKey], + ) -> None: + ... + + @abstractmethod + def handle_dict_elem_get( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_Dict_elem_GET, + ) -> TV_RestResourceBase: + ... + + @abstractmethod + def handle_dict_elem_put( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_Dict_elem_PUT, + ) -> TV_RestResourceBase: + ... + + +class ResourcePlugin_dict_default(ResourcePlugin_dict[_T_DictKey, _T_DictValues]): + """default implementation of RestResourcePlugin_dict""" + + def handle_dict_get_keys( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + params: RestRequestParams_Dict_GET, + ) -> list[_T_DictKey]: + return list(resource_dict.keys()) + + def handle_dict_post( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + resource: _T_DictValues, + params: RestRequestParams_Dict_POST[_T_DictKey], + ) -> Optional[_T_DictKey]: + if params.API_key is not None: + resource_dict[params.API_key] = resource + return params.API_key + + def handle_dict_delete( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + params: RestRequestParams_Dict_DELETE[_T_DictKey], + ) -> None: + if params.API_key is not None: + del resource_dict[params.API_key] + + def handle_dict_elem_get( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_Dict_elem_GET, + ) -> TV_RestResourceBase: + return resource + + def handle_dict_elem_put( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_Dict_elem_PUT, + ) -> TV_RestResourceBase: + return resource diff --git a/src/pyrestresource/rest_resource_walker.py b/src/pyrestresource/rest_resource_walker.py new file mode 100644 index 0000000..caa247c --- /dev/null +++ b/src/pyrestresource/rest_resource_walker.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from typing import ( + cast, + Any, + Optional, + Union, + get_args, + get_origin, + TypeVar, + Generic, + TYPE_CHECKING, +) +from typing import Type +from abc import ABC, abstractmethod + +from pydantic.fields import FieldInfo + +from .rest_types import _T_SupportedRESTFields + + +if TYPE_CHECKING: + from .rest_resource import RestResourceBase + +TV_RestResourceWalkerFutureResult = TypeVar("TV_RestResourceWalkerFutureResult") + + +class RestResourceWalkerFutureResult(ABC, Generic[TV_RestResourceWalkerFutureResult]): + def __init__(self, source: RestResourceWalker_Sub): + self.source: RestResourceWalker_Sub = source + + def chain_process_future(self) -> Optional[TV_RestResourceWalkerFutureResult]: + return self.source.chain_process_future() + + @abstractmethod + def process_future(self, result: Optional[list[TV_RestResourceWalkerFutureResult]]) -> Optional[TV_RestResourceWalkerFutureResult]: + pass + + +class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): + cls_RestResourceWalkerFutureResult: Optional[type[RestResourceWalkerFutureResult[TV_RestResourceWalkerFutureResult]]] = None + + @classmethod + @abstractmethod + def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: + """implementation interface to Factory. + The factory will call this specialized method on each implementation to find a supported one. + """ + ... + + @classmethod + def get( + self, + subs: list[type[RestResourceWalker_Sub]], + resource_name: str, + resource: FieldInfo | Type["RestResourceBase"], + parent: Optional[RestResourceWalker_Sub] = None, + ) -> Optional[RestResourceWalker_Sub]: + for sub in subs: + _is_valid, _anno, _optional = sub.check_type(resource) + + if _is_valid is True: + return sub(resource_name, resource, parent, _anno, _optional) + raise RuntimeError(f"Incompatible Field Found: {type(resource).__name__}") + return None + + def __init__( + self, + resource_name: str, + resource: FieldInfo | Type["RestResourceBase"], + parent: Optional[RestResourceWalker_Sub] = None, + annotation: Optional[type["RestResourceBase"]] = None, + optional: Optional[bool] = None, + ): + self.resource_name: str = resource_name + self.resource: FieldInfo | Type["RestResourceBase"] = resource + self.parent: Optional[RestResourceWalker_Sub] = parent + + self.future_results_subs: Optional[list[RestResourceWalkerFutureResult[TV_RestResourceWalkerFutureResult]]] = None + self.future_result: Optional[RestResourceWalkerFutureResult[TV_RestResourceWalkerFutureResult]] = None + if self.cls_RestResourceWalkerFutureResult is not None: + self.future_results_subs = [] + self.future_result = self.cls_RestResourceWalkerFutureResult(self) + + self.annotation: type["RestResourceBase"] + self.optional: bool + if annotation is None or optional is None: + self.annotation, self.optional = self.ProcessAnnotation(resource) + else: + self.annotation = annotation + self.optional = optional + + if self.annotation is None: + raise RuntimeError("Only annotated types are allowed in RestResourceBase derived classes") + + self.subdatatype = get_args(self.annotation) + + # self.info() + + def info(self) -> None: + print(f"{type(self).__name__}->info()") + print("==========================") + print(f"resource_name: {self.resource_name}") + print(f"resource: {type(self.resource).__name__}") + print(f"resource: {self.resource}") + print(f"parent: {self.parent}") + print(f"annotation: {self.annotation}") + print(f"optional: {self.optional}") + print(f"subdatatype: {self.subdatatype}") + + # -> cannot do that on dicts + # if self.parent is not None: + # print(f"_model_dump_excluded_: {self.parent.annotation._model_dump_excluded_}") + + if False: + print("------ STACK ------") + _rsrc = self.parent + while _rsrc is not None: + print(f"{id(_rsrc.annotation)}:{_rsrc.annotation}") + _rsrc = _rsrc.parent + print("-------------------") + + @classmethod + def init_sub(cls, walker: RestResourceWalker_Root) -> None: + pass + + @abstractmethod + def get_future(self) -> Optional[RestResourceWalkerFutureResult]: + return self.future_result + + def chain_process_future(self) -> Optional[TV_RestResourceWalkerFutureResult]: + if self.future_results_subs is not None and self.future_result is not None: + return_future_results_subs: list[Any] = [] # TODO: use typevar + for future_result in self.future_results_subs: + return_future_results_subs.append(future_result.chain_process_future()) + return self.future_result.process_future(return_future_results_subs) + return None + + def collect_future_result( + self, + process_future_result: Optional[RestResourceWalkerFutureResult[TV_RestResourceWalkerFutureResult]], + ) -> None: + if process_future_result is not None and self.future_results_subs is not None and self.future_result is not None: + self.future_results_subs.append(process_future_result) + + # @abstractmethod + def get_sub_resources(self) -> list[tuple[str, FieldInfo]]: + return [] + + def process(self): + pass + + @staticmethod + def ProcessAnnotation( + resource: FieldInfo | Type["RestResourceBase"], + ) -> tuple[type[Any], bool]: + from .rest_resource import RestResourceBase + + _anno: Type[Any] + + # print("!!!!!!!!!!!!!!!!!!!!!!!") + # print(resource) + # print(type(resource)) + + if isinstance(resource, FieldInfo) and resource.annotation is not None: + _anno = resource.annotation + elif not isinstance(resource, FieldInfo) and issubclass(resource, RestResourceBase): + _anno = resource + else: + raise RuntimeError("Incompatible resource type") + + _datatype = get_args(_anno) + _optional: bool = False + if get_origin(_anno) is Union: + if len(_datatype) == 2: + if _datatype[0] is type(None): + _anno = _datatype[1] + _optional = True + elif _datatype[1] is type(None): + _anno = _datatype[0] + _optional = True + else: + raise RuntimeError("Union is only allowed to describe Optional (e.g. Union[XXX,None])") + + return _anno, _optional + + +class RestResourceWalker_Sub_T_Dict(RestResourceWalker_Sub): + @classmethod + def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: + _anno, _optional = cls.ProcessAnnotation(resource) + _type_resource = get_origin(_anno) + return (_type_resource is dict), _anno, _optional + + def get_sub_resources(self) -> list[tuple[str, FieldInfo]]: + return [(self.resource_name, self.subdatatype[1])] + + def get_future(self) -> Optional[RestResourceWalkerFutureResult]: + return self.future_result + + +class RestResourceWalker_Sub_RestFields(RestResourceWalker_Sub): + @classmethod + def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: + _anno, _optional = cls.ProcessAnnotation(resource) + return (_anno in _T_SupportedRESTFields), _anno, _optional + + def get_future(self) -> Optional[RestResourceWalkerFutureResult]: + return self.future_result + + +class RestResourceWalker_Sub_RestResourceBase(RestResourceWalker_Sub): + @classmethod + def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: + from .rest_resource import RestResourceBase + + _anno, _optional = cls.ProcessAnnotation(resource) + return ( + ((get_origin(_anno) is None) and issubclass(_anno, RestResourceBase)), + _anno, + _optional, + ) + + def get_sub_resources(self) -> list[tuple[str, FieldInfo]]: + return [(cast(str, key), attr) for key, attr in self.annotation.model_fields.items()] + + def get_future(self) -> Optional[RestResourceWalkerFutureResult]: + return self.future_result + + +class RestResourceWalker_Root: + cls_RestResourceWalker_Sub: list[Type[RestResourceWalker_Sub]] = [ + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, + ] + + def __init__(self, resource: "RestResourceBase" | Type["RestResourceBase"]) -> None: + from .rest_resource import RestResourceBase + + self.resource: Type["RestResourceBase"] + if isinstance(resource, RestResourceBase): + self.resource = type(resource) + else: + self.resource = resource + + def process(self, deep_limit: Optional[int] = None) -> Optional[TV_RestResourceWalkerFutureResult]: + current_deep: int = 0 + for cls_Sub in self.cls_RestResourceWalker_Sub: + _self = self + cls_Sub.init_sub(_self) + + sub_walker_initial: Optional[RestResourceWalker_Sub] = RestResourceWalker_Sub.get( + self.cls_RestResourceWalker_Sub, "/", self.resource, None + ) + + if sub_walker_initial is not None: + sub_walker_initial.get_future() + resource_list: list[tuple[str, FieldInfo | Type["RestResourceBase"], RestResourceWalker_Sub]] = [ + (subresource_name, subresource, sub_walker_initial) + for subresource_name, subresource in sub_walker_initial.get_sub_resources() + ] + + new_resource_list: list[tuple[str, FieldInfo, RestResourceWalker_Sub]] + sub_walker: Optional[RestResourceWalker_Sub] + while len(resource_list) > 0 and (deep_limit is None or current_deep < deep_limit): + new_resource_list = [] + for resource_name, resource, parent_sub_walker in resource_list: + sub_walker = RestResourceWalker_Sub.get( + self.cls_RestResourceWalker_Sub, + resource_name, + resource, + parent_sub_walker, + ) + if sub_walker is not None: + sub_walker.process() + process_future_result: Optional[RestResourceWalkerFutureResult] = sub_walker.get_future() + parent_sub_walker.collect_future_result(process_future_result) + new_resource_list.extend( + [ + (subresource_name, subresource, sub_walker) + for subresource_name, subresource in sub_walker.get_sub_resources() + ] + ) + + resource_list = list(new_resource_list) + current_deep = current_deep + 1 + return sub_walker_initial.chain_process_future() + else: + raise RuntimeError("Invalid Rootpoint") + return None diff --git a/src/pyrestresource/rest_types.py b/src/pyrestresource/rest_types.py new file mode 100644 index 0000000..17a95bd --- /dev/null +++ b/src/pyrestresource/rest_types.py @@ -0,0 +1,105 @@ +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring +from __future__ import annotations +from enum import Enum, auto +from typing import Union, get_origin, NewType, TypeVar, TYPE_CHECKING +from datetime import datetime +from pathlib import Path +from uuid import UUID +from ipaddress import IPv4Address, IPv4Network + +if TYPE_CHECKING: + from .rest_resource import RestResourceBase + +T_Gen_DictKeys: type = type({}.keys()) +NoneType = type(None) + + +class rsrc_verb(Enum): + GET = auto() + PUT = auto() + POST = auto() + DELETE = auto() + + +class rsrc_type(Enum): + resource = auto() + dict = auto() + list = auto() + field = auto() + + +_T_SupportedRESTFields = [ + UUID, + str, + int, + float, + bool, + bytes, + datetime, + Path, + IPv4Address, + IPv4Network, +] +T_SupportedRESTFields = Union[ + UUID, str, int, float, bool, bytes, datetime, Path, IPv4Address, IPv4Network +] +TV_SupportedRESTFields = TypeVar( + "TV_SupportedRESTFields", + UUID, + str, + int, + float, + bool, + bytes, + datetime, + Path, + IPv4Address, + IPv4Network, +) + +if get_origin(T_SupportedRESTFields) is not Union: + raise RuntimeError("wrong T_SupportedRESTFields (must be flat Union)") + +TV_RestResourceBase = TypeVar("TV_RestResourceBase", bound="RestResourceBase") + +T_FieldValue = Union[T_SupportedRESTFields, "RestResourceBase"] + + +T_ListIndex = NewType("T_ListIndex", int) +T_ListSize = NewType("T_ListSize", int) + +T_DictKey = Union[ + UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network +] # datetime is removed because non-hashable +_T_DictKey = TypeVar( + "_T_DictKey", UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network +) + +T_T_DictKey = type[T_DictKey] + + +T_DictValues = T_FieldValue +_T_DictValues = TypeVar( + "_T_DictValues", + UUID, + str, + int, + float, + bool, + bytes, + datetime, + Path, + IPv4Address, + IPv4Network, + "RestResourceBase", +) + +T_T_FieldValue = type(T_FieldValue) +T_T_DictValues = type[T_DictValues] + +T_Dict = dict[T_DictKey, T_DictValues] +_T_Dict = dict[_T_DictKey, _T_DictValues] + +T_AllSupportedFields = T_Dict | T_FieldValue +T_AllSupportedFiels = T_Dict | T_FieldValue +T_AllSupportedContainers = Union[T_Dict, "RestResourceBase"] diff --git a/src/pyrestresource/test_module.py b/src/pyrestresource/test_module.py deleted file mode 100644 index e0275df..0000000 --- a/src/pyrestresource/test_module.py +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# pyrestresource (c) by chacha -# -# pyrestresource is licensed under a -# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. -# -# You should have received a copy of the license along with this -# work. If not, see . - -"""Phasellus tellus lectus, volutpat eu dapibus ut, suscipit vel augue. - -Tips: - Aliquam non leo vel libero sagittis viverra. Quisque lobortis nunc sit amet augue euismod laoreet. -Note: - Maecenas volutpat porttitor pretium. Aliquam suscipit quis nisi non imperdiet. -Note: - Vivamus et efficitur lorem, eget imperdiet tortor. Integer vel interdum sem. -""" - -from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: # Only imports the below statements during type checking - pass - -def test_function(testvar: int) -> int: - """ A test function that return testvar+1 and print "Hello world !" - - Proin eget sapien eget ipsum efficitur mollis nec ac nibh. - - Note: - Morbi id lectus maximus, condimentum nunc eget, porta felis. In tristique velit tortor. - - Args: - testvar: any integer - - Returns: - testvar+1 - """ - print("Hello world !") - return testvar+1 diff --git a/test/__init__.py b/test/__init__.py index 8a7f597..0aef653 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -4,4 +4,4 @@ # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. # # You should have received a copy of the license along with this -# work. If not, see . \ No newline at end of file +# work. If not, see . diff --git a/test/test_rest_login.py b/test/test_rest_login.py new file mode 100644 index 0000000..29114d1 --- /dev/null +++ b/test/test_rest_login.py @@ -0,0 +1,80 @@ +from __future__ import annotations +import unittest +from unittest.mock import patch +from os import chdir +from pathlib import Path +from typing import Optional, Annotated +from pydantic import Field +from uuid import UUID, uuid4 +from time import time +import json + +print(__name__) +print(__package__) + +from src.pyrestresource import ( + register_rest_rootpoint, + RestResourceBase, + rsrc_verb, + RestRequestParams_GET, + RestRequestParams_POST, + RestRequestParams_Dict_GET, + RestRequestParams_PUT, + T_SupportedRESTFields, + ResourcePlugin_field_default, + ResourcePlugin_RestResourceBase_default, +) +from pprint import pprint + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +# to allow mock-ing, all the tested classes are in a function +def init_classes(): + class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): + def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login: + print("hook GET") + print(resource) + print(params) + return resource + + def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login: + print("hook PUT") + print(resource) + print(params) + return resource + + class Login(RestResourceBase): + username: Optional[str] = Field(None, exclude=True) + # username: Optional[str] = Field(None) + secret: Optional[str] = Field(None, exclude=True) + + @register_rest_rootpoint + class RootApp(RestResourceBase): + login: Login = Field( + default=Login(), + plugin=ResourcePlugin_Login, + ) + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Login.__name__] = Login + globals()[RootApp.__name__] = RootApp + + +class Test_RestAPI_LOGIN(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_login(self): + result = self.testapp.process_request("/login", rsrc_verb.GET) + print(result) + + result = self.testapp.process_request("/login", rsrc_verb.PUT, '{"username":"toto","secret":"123456"}') + print(result) + + result = self.testapp.process_request("/login", rsrc_verb.GET) + print(result) diff --git a/test/test_rest_resource.py b/test/test_rest_resource.py new file mode 100644 index 0000000..a54e0c8 --- /dev/null +++ b/test/test_rest_resource.py @@ -0,0 +1,557 @@ +from __future__ import annotations +import unittest +from unittest.mock import patch +from os import chdir +from pathlib import Path +from typing import Optional +from pydantic import Field +from uuid import UUID, uuid4 +from time import time +import json + + +print(__name__) +print(__package__) + +from src.pyrestresource import ( + register_rest_rootpoint, + RestResourceBase, + rsrc_verb, + RestRequestParams_GET, + RestRequestParams_POST, + RestRequestParams_Dict_GET, + T_SupportedRESTFields, +) +from pprint import pprint + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +# to allow mock-ing, all the tested classes are in a function +def init_classes(): + class Info(RestResourceBase): + version: str + api_version: str + + class Patch(RestResourceBase): + uuid: UUID = Field(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + class Profile(RestResourceBase): + uuid: UUID = Field(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + class Game(RestResourceBase): + uuid: UUID = Field(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + profiles: dict[UUID, Profile] = {} + patchs: dict[UUID, Patch] = {} + + Patch_1 = Patch(uuid="cee1e870-65fa-11ee-8c99-0242ac120002", shortname="testPatch1") + Patch_2 = Patch(uuid="d385a1d2-65fa-11ee-8c99-0242ac120002", shortname="testPatch2") + + class User(RestResourceBase): + uuid: UUID = Field(default_factory=uuid4, primary_key=True) + name: str + secret: str = Field(..., exclude=True) + + User1 = User( + uuid="8da57a3c-661f-11ee-8c99-0242ac120002", + name="chacha", + secret="la blanquette est bonne", + ) + + ext_patchs: dict[UUID, Patch] = {} + + class Patch2(RestResourceBase): + uuid: UUID = Field(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + @register_rest_rootpoint + class RootApp(RestResourceBase): + testValueRoot: float = 3.14 + info: Info = Info(version="0.0.1", api_version="0.0.2") + games: dict[UUID, Game] = { + UUID("9b0381d4-65f6-11ee-8c99-0242ac120002"): Game( + uuid="9b0381d4-65f6-11ee-8c99-0242ac120002", + shortname="testGame", + patchs={Patch_1.uuid: Patch_1}, + profiles={ + UUID("aee1e870-65fa-11ee-8c99-0242ac120002"): Profile( + uuid="aee1e870-65fa-11ee-8c99-0242ac120002", + shortname="testprofile", + ) + }, + ) + } + patchs: dict[UUID, Patch] = {Patch_1.uuid: Patch_1, Patch_2.uuid: Patch_2} + users: dict[UUID, User] = {User1.uuid: User1} + + patchs2: dict[UUID, Patch2] = {} + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Info.__name__] = Info + globals()[Game.__name__] = Game + globals()[User.__name__] = User + globals()[Profile.__name__] = Profile + globals()[Patch.__name__] = Patch + globals()[Patch2.__name__] = Patch2 + globals()[RootApp.__name__] = RootApp + + +class Test_RestAPI_GET(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_get_root(self): + result = self.testapp.process_request("/", rsrc_verb.GET) + self.assertEqual(result, '{"testValueRoot": 3.14}') + + def test_get_root__multiple_slash(self): + result = self.testapp.process_request("/////", rsrc_verb.GET) + self.assertEqual(result, '{"testValueRoot": 3.14}') + + result = self.testapp.process_request("////", rsrc_verb.GET) + self.assertEqual(result, '{"testValueRoot": 3.14}') + + def test_get_root__nested_value(self): + result = self.testapp.process_request("/testValueRoot", rsrc_verb.GET) + self.assertEqual(result, "3.14") + + def test_get_root__nested_value__trailing_slash(self): + result = self.testapp.process_request("/testValueRoot/", rsrc_verb.GET) + self.assertEqual(result, "3.14") + + result = self.testapp.process_request("/testValueRoot//", rsrc_verb.GET) + self.assertEqual(result, "3.14") + + result = self.testapp.process_request("/testValueRoot///", rsrc_verb.GET) + self.assertEqual(result, "3.14") + + def test_get_root__nested_value__multiple_slash(self): + result = self.testapp.process_request("//testValueRoot", rsrc_verb.GET) + self.assertEqual(result, "3.14") + + result = self.testapp.process_request("///testValueRoot", rsrc_verb.GET) + self.assertEqual(result, "3.14") + + def test_get_version(self): + result = self.testapp.process_request("/info", rsrc_verb.GET) + self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + + def test_get_version__trailing_slash(self): + result = self.testapp.process_request("/info/", rsrc_verb.GET) + self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + + result = self.testapp.process_request("/info//", rsrc_verb.GET) + self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + + result = self.testapp.process_request("/info///", rsrc_verb.GET) + self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + + def test_get_version__multiple_slash(self): + result = self.testapp.process_request("//info", rsrc_verb.GET) + self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + + result = self.testapp.process_request("///info", rsrc_verb.GET) + self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + + def test_get_version__nested_value(self): + result = self.testapp.process_request("/info/api_version", rsrc_verb.GET) + self.assertEqual(result, '"0.0.2"') + + result = self.testapp.process_request("/info/version", rsrc_verb.GET) + self.assertEqual(result, '"0.0.1"') + + def test_get_dict_games(self): + result = self.testapp.process_request("/games", rsrc_verb.GET) + self.assertEqual(result, '["9b0381d4-65f6-11ee-8c99-0242ac120002"]') + + def test_get_dict_patchs(self): + result = self.testapp.process_request("/patchs", rsrc_verb.GET) + self.assertEqual( + result, + '["cee1e870-65fa-11ee-8c99-0242ac120002", "d385a1d2-65fa-11ee-8c99-0242ac120002"]', + ) + + def test_get_dict_patch_element(self): + result = self.testapp.process_request("/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", rsrc_verb.GET) + self.assertEqual( + result, + '{"uuid": "cee1e870-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch1", "name": null, "description": null}', + ) + + def test_get_dict_game_element(self): + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002", rsrc_verb.GET) + expected = '{"uuid": "9b0381d4-65f6-11ee-8c99-0242ac120002", "shortname": "testGame", "name": null, "description": null}' + self.assertEqual(result, expected) + + def test_get_dict_game_element__nested_value(self): + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname", rsrc_verb.GET) + expected = '"testGame"' + self.assertEqual(result, expected) + + def test_get_dict_game_element__nested_value2(self): + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/uuid", rsrc_verb.GET) + expected = '"9b0381d4-65f6-11ee-8c99-0242ac120002"' + self.assertEqual(result, expected) + + def test_get_nested_dict_games_patchs(self): + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) + self.assertEqual(result, '["cee1e870-65fa-11ee-8c99-0242ac120002"]') + + def test_get_nested_dict_games_patch_element(self): + result = self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.GET, + ) + expected = '{"uuid": "cee1e870-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch1", "name": null, "description": null}' + self.assertEqual(result, expected) + + def test_get_nested_dict_games_patch_element__nested_value(self): + result = self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/uuid", + rsrc_verb.GET, + ) + self.assertEqual(result, '"cee1e870-65fa-11ee-8c99-0242ac120002"') + + def test_get_dict_game_element__API_nested(self): + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002?API_nested=True", rsrc_verb.GET) + expected = '{"uuid": "9b0381d4-65f6-11ee-8c99-0242ac120002", "shortname": "testGame", "name": null, "description": null}' + self.assertEqual(result, expected) + + def test_get_dict_users(self): + result = self.testapp.process_request("/users", rsrc_verb.GET) + self.assertEqual(result, '["8da57a3c-661f-11ee-8c99-0242ac120002"]') + + def test_get_dict_user_element(self): + result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.GET) + self.assertEqual( + result, + '{"uuid": "8da57a3c-661f-11ee-8c99-0242ac120002", "name": "chacha"}', + "no secret seen", + ) + + def test_get_dict_user_element2(self): + result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002?API_nested=True", rsrc_verb.GET) + self.assertEqual( + result, + '{"uuid": "8da57a3c-661f-11ee-8c99-0242ac120002", "name": "chacha"}', + "no secret seen", + ) + + def test_get_dict_user_element__nested_value(self): + result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002/name", rsrc_verb.GET) + self.assertEqual(result, '"chacha"') + + def test_get_dict_user_element__nested_value__forbiden(self): + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002/secret", rsrc_verb.GET) + + def test_get_dict_user_element__nested_value__forbiden2(self): + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request( + "/users/8da57a3c-661f-11ee-8c99-0242ac120002/secret?API_nested=True", + rsrc_verb.GET, + ) + + +class Test_RestAPI_PUT(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_put_info(self): + self.testapp.process_request("/info", rsrc_verb.PUT, '{"version": "1.2.3", "api_version": "3.2.1"}') + + result = self.testapp.process_request("/info", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.2.3", "api_version": "3.2.1"}') + + def test_put_dict_user_nested_value(self): + self.testapp.process_request( + "/users/8da57a3c-661f-11ee-8c99-0242ac120002/name", + rsrc_verb.PUT, + '"chacha2"', + ) + + result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002/name", rsrc_verb.GET) + self.assertEqual(result, '"chacha2"') + + def test_put_user_nested_value__forbiden(self): + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request( + "/users/8da57a3c-661f-11ee-8c99-0242ac120002/secret", + rsrc_verb.PUT, + '"test"', + ) + + def test_put_dict_user_element(self): + self.testapp.process_request( + "/users/8da57a3c-661f-11ee-8c99-0242ac120002", + rsrc_verb.PUT, + '{"name": "testUser4", "secret": "test5"}', + ) + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = '["8da57a3c-661f-11ee-8c99-0242ac120002"]' + self.assertEqual(result, expected) + + result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.GET) + expected = '{"uuid": "8da57a3c-661f-11ee-8c99-0242ac120002", "name": "testUser4"}' + self.assertEqual(result, expected) + + def test_put_dict_patch__nested(self): + self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.PUT, + '{"shortname": "testPatch998", "name": "MyPatch", "description": "MyDescription123"}', + ) + + result = self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.GET, + ) + expected = '{"uuid": "cee1e870-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch998", "name": "MyPatch", "description": "MyDescription123"}' + self.assertEqual(result, expected) + + +class Test_RestAPI_POST(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_post_dict_user__API_key(self): + result = self.testapp.process_request( + "/users?API_key=e5e87d32-662b-11ee-8c99-0242ac120002", + rsrc_verb.POST, + '{"name": "testUser", "secret": "test"}', + ) + self.assertEqual(result, '"e5e87d32-662b-11ee-8c99-0242ac120002"') + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "e5e87d32-662b-11ee-8c99-0242ac120002"]' + self.assertEqual(result, expected) + + result = self.testapp.process_request("/users/e5e87d32-662b-11ee-8c99-0242ac120002", rsrc_verb.GET) + expected = '{"uuid": "e5e87d32-662b-11ee-8c99-0242ac120002", "name": "testUser"}' + self.assertEqual(result, expected) + + def test_post_dict_user__nested_key(self): + result = self.testapp.process_request( + "/users", + rsrc_verb.POST, + '{"name": "testUser2", "secret": "test", "uuid":"e7e86d32-662b-11ee-8c99-0242ac120002"}', + ) + self.assertEqual(result, '"e7e86d32-662b-11ee-8c99-0242ac120002"') + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "e7e86d32-662b-11ee-8c99-0242ac120002"]' + self.assertEqual(result, expected) + + result = self.testapp.process_request("/users/e7e86d32-662b-11ee-8c99-0242ac120002", rsrc_verb.GET) + expected = '{"uuid": "e7e86d32-662b-11ee-8c99-0242ac120002", "name": "testUser2"}' + self.assertEqual(result, expected) + + @patch(f"{__loader__.name }.uuid4") + def test_post_dict_user__auto_key(self, mock_uuid4): + mock_uuid4.return_value = UUID("5faccb2e-69aa-11ee-8c99-0242ac120002") + + # recreating classes & objects to force using the Mock-ed uuid4 + init_classes() + self.testapp = RootApp() + + result = self.testapp.process_request("/users", rsrc_verb.POST, '{"name": "testUser3", "secret": "test"}') + self.assertEqual(result, '"5faccb2e-69aa-11ee-8c99-0242ac120002"') + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "5faccb2e-69aa-11ee-8c99-0242ac120002"]' + self.assertEqual(result, expected) + + result = self.testapp.process_request("/users/5faccb2e-69aa-11ee-8c99-0242ac120002", rsrc_verb.GET) + expected = '{"uuid": "5faccb2e-69aa-11ee-8c99-0242ac120002", "name": "testUser3"}' + self.assertEqual(result, expected) + + def test_post_dict_patch__nested_API_key(self): + self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs?API_key=cee1e971-65fa-11ee-8c99-0242ac120002", + rsrc_verb.POST, + '{"shortname": "testPatch99", "name": "MyPatch", "description": "MyDescription"}', + ) + + result = self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e971-65fa-11ee-8c99-0242ac120002", + rsrc_verb.GET, + ) + expected = '{"uuid": "cee1e971-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch99", "name": "MyPatch", "description": "MyDescription"}' + self.assertEqual(result, expected) + + +class Test_RestAPI_DELETE(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_delete_dict_user__API_key(self): + self.testapp.process_request("/users?API_key=8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.DELETE) + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result, expected) + + def test_delete_dict_user__All(self): + result = self.testapp.process_request( + "/users?API_key=e5e87d32-662b-11ee-8c99-0242ac120002", + rsrc_verb.POST, + '{"name": "testUser", "secret": "test"}', + ) + self.assertEqual(result, '"e5e87d32-662b-11ee-8c99-0242ac120002"') + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "e5e87d32-662b-11ee-8c99-0242ac120002"]' + self.assertEqual(result, expected) + + self.testapp.process_request("/users", rsrc_verb.DELETE) + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result, expected) + + def test_delete_dict_user_element(self): + self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.DELETE) + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result, expected) + + def test_delete_nested_dict_games_patch_element(self): + self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.DELETE, + ) + + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result, expected) + + def test_delete_nested_dict_games_patch_API_key(self): + self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs?API_key=cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.DELETE, + ) + + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result, expected) + + def test_delete_nested_dict_games_patch_All(self): + self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.DELETE) + + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result, expected) + + +class Test_RestAPI_PERFO(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + @unittest.skip + def test_perf_dict(self): + print(f"LIB INTERNAL PERF TEST") + n_loop = 10000 + + start = time() + for i in range(n_loop): + self.testapp.process_request(f"/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.GET) + end = time() + print(f"GET 1st level dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for i in range(n_loop): + newUUID = uuid4() + self.testapp.process_request( + f"/users?API_key={newUUID}", + rsrc_verb.POST, + '{"name": "testUser", "secret": "test"}', + ) + end = time() + print(f"POST 1st level dict (API_key): {int(n_loop/(end-start))} Req/s") + + start = time() + for i in range(n_loop): + newUUID = uuid4() + self.testapp.process_request( + f"/users?API_key={newUUID}", + rsrc_verb.POST, + '{"name": "testUser", "secret": "test"}', + ) + self.testapp.process_request(f"/users/{newUUID}", rsrc_verb.GET) + end = time() + print(f"POST/GET 1st level dict (API_key): {int(n_loop/(end-start))} Req/s") + + start = time() + for i in range(n_loop): + result = self.testapp.process_request(f"/users", rsrc_verb.POST, '{"name": "testUser", "secret": "test"}') + self.testapp.process_request(f"/users/{json.loads(result)}", rsrc_verb.GET) + end = time() + print(f"POST/GET 1st level dict (autokey): {int(n_loop/(end-start))} Req/s") + + start = time() + for i in range(n_loop): + self.testapp.process_request( + f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname", + rsrc_verb.PUT, + '"TestValue!!"', + ) + self.testapp.process_request(f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname", rsrc_verb.GET) + end = time() + print(f"PUT/GET 1st level (value) dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for i in range(n_loop): + self.testapp.process_request( + f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.GET, + ) + end = time() + print(f"GET 2nd level dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for i in range(n_loop): + self.testapp.process_request( + f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + rsrc_verb.GET, + ) + end = time() + print(f"GET 2nd level (value) dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for i in range(n_loop): + self.testapp.process_request( + f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + rsrc_verb.PUT, + '"TestValue!!"', + ) + self.testapp.process_request( + f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + rsrc_verb.GET, + ) + end = time() + print(f"PUT/GET 2nd level (value) dict: {int(n_loop/(end-start))} Req/s") diff --git a/test/test_rest_resource_plugins.py b/test/test_rest_resource_plugins.py new file mode 100644 index 0000000..ac40326 --- /dev/null +++ b/test/test_rest_resource_plugins.py @@ -0,0 +1,216 @@ +from __future__ import annotations +import unittest +from unittest.mock import patch +from os import chdir +from pathlib import Path +from typing import Optional, Annotated +from pydantic import Field +from uuid import UUID, uuid4 +from time import time +import json + +print(__name__) +print(__package__) + +from src.pyrestresource import ( + register_rest_rootpoint, + RestResourceBase, + rsrc_verb, + RestRequestParams_GET, + RestRequestParams_POST, + RestRequestParams_Dict_GET, + RestRequestParams_PUT, + T_SupportedRESTFields, + ResourcePlugin_field_default, + ResourcePlugin_RestResourceBase_default, +) +from pprint import pprint + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +# to allow mock-ing, all the tested classes are in a function +def init_classes(): + class ResourcePlugin_version_get(ResourcePlugin_field_default): + def handle_field_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get: + return "1.5.6" + + class ResourcePlugin_version_put(ResourcePlugin_field_default): + def handle_field_put(self, resource: Info_put, params: RestRequestParams_PUT) -> Info_put: + return "42" + + class ResourcePlugin_Info(ResourcePlugin_RestResourceBase_default): + def handle_resource_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get: + return Info_get(version="65.45", api_version="98.321") + + class Info_get(RestResourceBase): + # test plugin injection within annotation + # + test plugin on a simple field + version: Annotated[str, Field(plugin=ResourcePlugin_version_get)] + api_version: str + + class Info_put(RestResourceBase): + # test plugin injection within annotation + # + test plugin on a simple field + version: Annotated[str, Field(plugin=ResourcePlugin_version_put)] + api_version: str + + @register_rest_rootpoint + class RootApp(RestResourceBase): + # test plugin injection within Field value + # + test plugin on a RestResourceBase field + info: Info_get = Field( + default=Info_get(version="0.0.1", api_version="0.0.2"), + plugin=ResourcePlugin_Info, + ) + info_put: Info_put = Field( + default=Info_put(version="0.0.1", api_version="0.0.2"), + ) + info2: Info_get = Field(default=Info_get(version="0.0.2", api_version="0.0.3")) + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Info_get.__name__] = Info_get + globals()[Info_put.__name__] = Info_put + globals()[RootApp.__name__] = RootApp + + +def init_bad_plugin1(): + # plugin with missing handle_resource_put() method + class ResourcePlugin_TestResource: + def handle_field_get(self, resource: TestResource, params: RestRequestParams_GET) -> TestResource: + return resource + + class TestResource(RestResourceBase): + tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] + + @register_rest_rootpoint + class RootApp2(RestResourceBase): + test: TestResource = Field(default=TestResource(tetvaluestr="testvalue")) + + RootApp2() + + +def init_bad_plugin2(): + # plugin with missing handle_resource_get() method + class ResourcePlugin_TestResource: + def handle_field_put(self, resource: TestResource, params: RestRequestParams_PUT) -> TestResource: + return resource + + class TestResource(RestResourceBase): + tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] + + @register_rest_rootpoint + class RootApp2(RestResourceBase): + test: TestResource = Field(default=TestResource(tetvaluestr="testvalue")) + + RootApp2() + + +def init_bad_plugin3(): + # wrong plugin + class ResourcePlugin_TestResource(ResourcePlugin_RestResourceBase_default): + pass + + class TestResource(RestResourceBase): + tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] + + @register_rest_rootpoint + class RootApp2(RestResourceBase): + test: TestResource = Field(default=TestResource(tetvaluestr="testvalue")) + + RootApp2() + + +class Test_RestAPI_Plugin_PUT(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_put_field_version_fieldplugin(self): + self.testapp.process_request("/info_put/version", rsrc_verb.PUT, '"1.5.6"') + + result = self.testapp.process_request("/info_put", rsrc_verb.GET) + print(result) + result = self.testapp.process_request("/info_put/version", rsrc_verb.GET) + print(result) + self.assertEqual(result, '"42"') + + def test_put_field_version_resourceplugin(self): + self.testapp.process_request("/info_put", rsrc_verb.PUT, '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("/info_put", rsrc_verb.GET) + self.assertEqual(result, '{"version": "42", "api_version": "98.321"}') + + +class Test_RestAPI_Plugin_GET(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_get_root(self): + result = self.testapp.process_request("/", rsrc_verb.GET) + self.assertEqual(result, "{}") + + def test_get_version(self): + result = self.testapp.process_request("/info", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("/info2", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + + def test_get_version__trailing_slash(self): + result = self.testapp.process_request("/info/", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("/info//", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("/info///", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("/info2/", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + + result = self.testapp.process_request("/info2//", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + + result = self.testapp.process_request("/info2///", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + + def test_get_version__multiple_slash(self): + result = self.testapp.process_request("//info", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("///info", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("//info2", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + + result = self.testapp.process_request("///info2", rsrc_verb.GET) + self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + + def test_get_version__nested_value(self): + result = self.testapp.process_request("/info/api_version", rsrc_verb.GET) + self.assertEqual(result, '"98.321"') + + result = self.testapp.process_request("/info/version", rsrc_verb.GET) + self.assertEqual(result, '"1.5.6"') + + result = self.testapp.process_request("/info2/api_version", rsrc_verb.GET) + self.assertEqual(result, '"0.0.3"') + + result = self.testapp.process_request("/info2/version", rsrc_verb.GET) + self.assertEqual(result, '"1.5.6"') + + def test_defect_plugin_field(self): + with self.assertRaises(RuntimeError): + init_bad_plugin1() + with self.assertRaises(RuntimeError): + init_bad_plugin2() + with self.assertRaises(RuntimeError): + init_bad_plugin3() diff --git a/test/test_rest_resource_walker.py b/test/test_rest_resource_walker.py new file mode 100644 index 0000000..a467e37 --- /dev/null +++ b/test/test_rest_resource_walker.py @@ -0,0 +1,163 @@ +from __future__ import annotations +import unittest + +from typing import Annotated, Optional + +from os import chdir +from pathlib import Path +from pydantic import Field +from io import StringIO +from contextlib import redirect_stdout, redirect_stderr + +print(__name__) +print(__package__) + +from src.pyrestresource import ( + RestResourceBase, +) + +from src.pyrestresource.rest_resource_walker import ( + RestResourceWalker_Root, + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, +) + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +class RestResourceWalker_Sub_T_Dict_TEST_Print(RestResourceWalker_Sub_T_Dict): + counter: dict[str, int] = {} + + @classmethod + def init_sub(cls, walker: RestResourceWalker_Root) -> None: + cls.counter = {} + + def process(self) -> None: + if self.resource_name not in self.counter: + self.counter[self.resource_name] = 0 + self.counter[self.resource_name] = self.counter[self.resource_name] + 1 + + print(f"DICT {self.resource_name} {self.counter[self.resource_name]}") + + +class RestResourceWalker_Sub_RestFields_TEST_Print(RestResourceWalker_Sub_RestFields): + counter: dict[str, int] = {} + + @classmethod + def init_sub(cls, walker: RestResourceWalker_Root) -> None: + cls.counter = {} + + def process(self) -> None: + if self.resource_name not in self.counter: + self.counter[self.resource_name] = 0 + self.counter[self.resource_name] = self.counter[self.resource_name] + 1 + + print(f"FIELD {self.resource_name} {self.counter[self.resource_name]}") + + +class RestResourceWalker_Sub_RestResourceBase_TEST_Print( + RestResourceWalker_Sub_RestResourceBase +): + counter: dict[str, int] = {} + + @classmethod + def init_sub(cls, walker: RestResourceWalker_Root) -> None: + cls.counter = {} + + def process(self) -> None: + if self.resource_name not in self.counter: + self.counter[self.resource_name] = 0 + self.counter[self.resource_name] = self.counter[self.resource_name] + 1 + + print(f"RestResource {self.resource_name} {self.counter[self.resource_name]}") + + +class RestResourceWalker_Root_TEST_Print(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict_TEST_Print, + RestResourceWalker_Sub_RestFields_TEST_Print, + RestResourceWalker_Sub_RestResourceBase_TEST_Print, + ] + + +def init_classes(): + class Info(RestResourceBase): + version: str + api_version: str + + class People(RestResourceBase): + last_name: str + + class RootApp(RestResourceBase): + info: Info = Field(default=Info(version="0.0.1", api_version="0.0.2")) + info2: Info = Info(version="0.0.2", api_version="0.0.3") + peoples: dict[str, People] = { + "john": People(last_name="Doe"), + "jane": People(last_name="Roe"), + } + test_string: str = "test value" + test_string_opt: Optional[str] = None + test_int: int = 42 + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Info.__name__] = Info + globals()[People.__name__] = People + globals()[RootApp.__name__] = RootApp + + +class Test_Walker(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + + def test_walk_class(self): + test = RestResourceWalker_Root_TEST_Print(RootApp) + with redirect_stdout(StringIO()) as capted_stdout, redirect_stderr( + StringIO() + ) as capted_stderr: + test.process() + self.assertIn("RestResource info 1", capted_stdout.getvalue()) + self.assertIn("RestResource info2 1", capted_stdout.getvalue()) + self.assertIn("DICT peoples 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_string 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_string_opt 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_int 1", capted_stdout.getvalue()) + self.assertIn("FIELD version 1", capted_stdout.getvalue()) + self.assertIn("FIELD version 2", capted_stdout.getvalue()) + self.assertIn("FIELD api_version 1", capted_stdout.getvalue()) + self.assertIn("FIELD api_version 2", capted_stdout.getvalue()) + self.assertIn("RestResource peoples 1", capted_stdout.getvalue()) + self.assertIn("FIELD last_name 1", capted_stdout.getvalue()) + + def test_walk_obj(self): + instRootApp = RootApp() + test = RestResourceWalker_Root_TEST_Print(instRootApp) + with redirect_stdout(StringIO()) as capted_stdout, redirect_stderr( + StringIO() + ) as capted_stderr: + test.process() + self.assertIn("RestResource info 1", capted_stdout.getvalue()) + self.assertIn("RestResource info2 1", capted_stdout.getvalue()) + self.assertIn("DICT peoples 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_string 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_string_opt 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_int 1", capted_stdout.getvalue()) + self.assertIn("FIELD version 1", capted_stdout.getvalue()) + self.assertIn("FIELD version 2", capted_stdout.getvalue()) + self.assertIn("FIELD api_version 1", capted_stdout.getvalue()) + self.assertIn("FIELD api_version 2", capted_stdout.getvalue()) + self.assertIn("RestResource peoples 1", capted_stdout.getvalue()) + self.assertIn("FIELD last_name 1", capted_stdout.getvalue()) + + def test_walk_obj_nested_RestResource(self): + instRootApp = RootApp() + test = RestResourceWalker_Root_TEST_Print(instRootApp.info) + with redirect_stdout(StringIO()) as capted_stdout, redirect_stderr( + StringIO() + ) as capted_stderr: + test.process() + self.assertIn("FIELD version 1", capted_stdout.getvalue()) + self.assertIn("FIELD api_version 1", capted_stdout.getvalue()) diff --git a/test/test_rest_resource_walker_tree.py b/test/test_rest_resource_walker_tree.py new file mode 100644 index 0000000..8ae2a48 --- /dev/null +++ b/test/test_rest_resource_walker_tree.py @@ -0,0 +1,150 @@ +from __future__ import annotations +import unittest + +from typing import Annotated, Optional + +from os import chdir +from pathlib import Path +from pydantic import Field +from io import StringIO +from contextlib import redirect_stdout, redirect_stderr + +print(__name__) +print(__package__) + +from src.pyrestresource import ( + RestResourceBase, +) + +from src.pyrestresource.rest_resource_walker import ( + RestResourceWalkerFutureResult, + RestResourceWalker_Root, + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, +) + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +class RestResourceWalkerFutureResult_RestFields_Test(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + res = dict() + res[self.source.resource_name] = False + return res + + +class RestResourceWalker_Sub_RestFields_TEST_Print(RestResourceWalker_Sub_RestFields): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestFields_Test + + +class RestResourceWalkerFutureResult_RestResourceBase_Test(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + res = dict() + res[self.source.resource_name] = dict() + for subres in result: + res[self.source.resource_name] = res[self.source.resource_name] | subres + return res + + +class RestResourceWalker_Sub_RestResourceBase_TEST_Print(RestResourceWalker_Sub_RestResourceBase): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestResourceBase_Test + + +class RestResourceWalkerFutureResult_Dict_Test(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + res = dict() + for subres in result: + res = res | subres + return res + + +class RestResourceWalker_Sub_T_Dict_TEST_Print(RestResourceWalker_Sub_T_Dict): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_Dict_Test + + +class RestResourceWalker_Root_TEST_Print(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict_TEST_Print, + RestResourceWalker_Sub_RestFields_TEST_Print, + RestResourceWalker_Sub_RestResourceBase_TEST_Print, + ] + + +def init_classes(): + class Info(RestResourceBase): + version: str + api_version: str + + class People(RestResourceBase): + last_name: str + + class RootApp(RestResourceBase): + info: Info = Field(default=Info(version="0.0.1", api_version="0.0.2")) + info2: Info = Info(version="0.0.2", api_version="0.0.3") + peoples: dict[str, People] = { + "john": People(last_name="Doe"), + "jane": People(last_name="Roe"), + } + test_string: str = "test value" + test_string_opt: Optional[str] = None + test_int: int = 42 + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Info.__name__] = Info + globals()[People.__name__] = People + globals()[RootApp.__name__] = RootApp + + +class Test_Walker_tree(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + + def test_walk_class(self): + test = RestResourceWalker_Root_TEST_Print(RootApp) + res = test.process() + self.assertDictEqual( + res, + { + "/": { + "info": {"version": False, "api_version": False}, + "info2": {"version": False, "api_version": False}, + "peoples": {"last_name": False}, + "test_string": False, + "test_string_opt": False, + "test_int": False, + } + }, + ) + + def test_walk_obj(self): + instRootApp = RootApp() + test = RestResourceWalker_Root_TEST_Print(instRootApp) + res = test.process() + self.assertDictEqual( + res, + { + "/": { + "info": {"version": False, "api_version": False}, + "info2": {"version": False, "api_version": False}, + "peoples": {"last_name": False}, + "test_string": False, + "test_string_opt": False, + "test_int": False, + } + }, + ) + + def test_walk_obj_nested_RestResource(self): + instRootApp = RootApp() + test = RestResourceWalker_Root_TEST_Print(instRootApp.info) + res = test.process() + self.assertDictEqual( + res, + { + "/": {"version": False, "api_version": False}, + }, + ) diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py new file mode 100644 index 0000000..28edf77 --- /dev/null +++ b/test/test_rest_webserver.py @@ -0,0 +1,382 @@ +from __future__ import annotations +import unittest +from unittest.mock import patch +from os import chdir +from pathlib import Path +from typing import Optional +from pydantic import Field +from uuid import UUID, uuid4 +from time import time, sleep +import json +import uvicorn +import socket +import requests +from contextlib import closing +from multiprocessing import Process + + +print(__name__) +print(__package__) + +from src.pyrestresource import ( + register_rest_rootpoint, + RestResourceBase, + rsrc_verb, + RestRequestParams_GET, + RestRequestParams_POST, + RestRequestParams_Dict_GET, + T_SupportedRESTFields, +) +from pprint import pprint + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +# to allow mock-ing, all the tested classes are in a function +def init_classes(): + class Info(RestResourceBase): + version: str + api_version: str + + class Patch(RestResourceBase): + uuid: UUID = Field(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + class Profile(RestResourceBase): + uuid: UUID = Field(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + class Game(RestResourceBase): + uuid: UUID = Field(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + profiles: dict[UUID, Profile] = {} + patchs: dict[UUID, Patch] = {} + + Patch_1 = Patch(uuid="cee1e870-65fa-11ee-8c99-0242ac120002", shortname="testPatch1") + Patch_2 = Patch(uuid="d385a1d2-65fa-11ee-8c99-0242ac120002", shortname="testPatch2") + + class User(RestResourceBase): + uuid: UUID = Field(default_factory=uuid4, primary_key=True) + name: str + secret: str = Field(..., exclude=True) + + User1 = User( + uuid="8da57a3c-661f-11ee-8c99-0242ac120002", + name="chacha", + secret="la blanquette est bonne", + ) + + class Patch2(RestResourceBase): + uuid: UUID = Field(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + @register_rest_rootpoint + class RootApp(RestResourceBase): + testValueRoot: float = 3.14 + info: Info = Info(version="0.0.1", api_version="0.0.2") + games: dict[UUID, Game] = { + UUID("9b0381d4-65f6-11ee-8c99-0242ac120002"): Game( + uuid="9b0381d4-65f6-11ee-8c99-0242ac120002", + shortname="testGame Origin", + description="test Game Desc Origin", + patchs={Patch_1.uuid: Patch_1}, + profiles={ + UUID("aee1e870-65fa-11ee-8c99-0242ac120002"): Profile( + uuid="aee1e870-65fa-11ee-8c99-0242ac120002", + shortname="testprofile", + ) + }, + ) + } + patchs: dict[UUID, Patch] = {Patch_1.uuid: Patch_1, Patch_2.uuid: Patch_2} + users: dict[UUID, User] = {User1.uuid: User1} + + patchs2: dict[UUID, Patch2] = {} + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Info.__name__] = Info + globals()[Game.__name__] = Game + globals()[User.__name__] = User + globals()[Profile.__name__] = Profile + globals()[Patch.__name__] = Patch + globals()[Patch2.__name__] = Patch2 + globals()[RootApp.__name__] = RootApp + + +def find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + hostname = socket.gethostname() + IPAddr = socket.gethostbyname(hostname) + return "localhost", s.getsockname()[1] + + +def launch_server(ip, port): + print(f"port2={port}") + init_classes() + uvicorn.run(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True) + + +class Test_RestAPI_WebServer(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + + def test_nomal_AllCmd_games(self): + ip, port = find_free_port() + print(f"ip1={ip}") + print(f"port1={port}") + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + try: + # Fetching games + response = s.get(f"http://{ip}:{port}/games") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIsInstance(data, list) + self.assertEqual( + data, + ["9b0381d4-65f6-11ee-8c99-0242ac120002"], + ) + + # Login in + """ + response = s.post( + f"http://{ip}:{port}/login", + params={"username": "test", "password": "test"}, + ) + self.assertEqual(response.status_code, 200) + """ + + # Add a new one (with all values setted) + response = s.post( + f"http://{ip}:{port}/games", + json={ + "shortname": "test", + "name": "nametest", + "description": "test Game Desc", + }, + ) + self.assertEqual(response.status_code, 201) + data = response.json() + NEW_GAME_UUID = UUID(data) + + # Fetching games again + response = s.get(f"http://{ip}:{port}/games") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIsInstance(data, list) + self.assertEqual( + data, + ["9b0381d4-65f6-11ee-8c99-0242ac120002", str(NEW_GAME_UUID)], + ) + + # Getting accurate values of created element + response = s.get(f"http://{ip}:{port}/games/{str(NEW_GAME_UUID)}") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIsInstance(data, dict) + self.assertIsInstance(data, dict) + self.assertIn("shortname", data) + self.assertIn("name", data) + self.assertIn("description", data) + self.assertIn("uuid", data) + NEW_GAME_UUID = UUID(data["uuid"]) + del data["uuid"] + self.assertDictEqual( + data, + { + "name": "nametest", + "shortname": "test", + "description": "test Game Desc", + }, + ) + + # removing the new one + response = s.delete(f"http://{ip}:{port}/games/{str(NEW_GAME_UUID)}") + self.assertEqual(response.status_code, 200) + + # Fetching games again + response = s.get(f"http://{ip}:{port}/games") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual( + data, + ["9b0381d4-65f6-11ee-8c99-0242ac120002"], + ) + + # Getting accurate values + response = s.get(f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIsInstance(data, dict) + self.assertIsInstance(data, dict) + self.assertIn("shortname", data) + self.assertIn("name", data) + self.assertIn("description", data) + self.assertIn("uuid", data) + NEW_GAME_UUID = UUID(data["uuid"]) + del data["uuid"] + self.assertDictEqual( + data, + { + "name": None, + "shortname": "testGame Origin", + "description": "test Game Desc Origin", + }, + ) + + # Update values + response = s.put( + f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002", + json={ + "name": "MyName", + }, + ) + self.assertEqual(response.status_code, 201) + + # Getting accurate values + response = s.get(f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIsInstance(data, dict) + self.assertIsInstance(data, dict) + self.assertIn("shortname", data) + self.assertIn("name", data) + self.assertIn("description", data) + self.assertIn("uuid", data) + NEW_GAME_UUID = UUID(data["uuid"]) + del data["uuid"] + self.assertDictEqual( + data, + { + "name": "MyName", + "shortname": "testGame Origin", + "description": "test Game Desc Origin", + }, + ) + + # removing original element + response = s.delete(f"http://{ip}:{port}/games?API_key={str(NEW_GAME_UUID)}") + self.assertEqual(response.status_code, 200) + + # Fetching games again + response = s.get(f"http://{ip}:{port}/games") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertTrue(len(data) == 0) + finally: + proc.terminate() + s.close() + + @unittest.skip + def test_perf_dict(self): + print(f"SOCKET PERF TEST") + n_loop = 10000 + + ip, port = find_free_port() + print(f"ip1={ip}") + print(f"port1={port}") + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + try: + start = time() + for _ in range(n_loop): + s.get(f"http://{ip}:{port}/users/8da57a3c-661f-11ee-8c99-0242ac120002") + end = time() + print(f"GET 1st level dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + newUUID = uuid4() + s.post( + f"http://{ip}:{port}/users?API_key={newUUID}", + json={"name": "testUser", "secret": "test"}, + ) + end = time() + + print(f"POST 1st level dict (API_key): {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + newUUID = uuid4() + s.post( + f"http://{ip}:{port}/users?API_key={str(newUUID)}", + json={"name": "testUser", "secret": "test"}, + ) + s.get(f"http://{ip}:{port}/users/{newUUID}") + end = time() + print(f"POST/GET 1st level dict (API_key): {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + response = s.post(f"http://{ip}:{port}/users", '{"name": "testUser", "secret": "test"}') + s.get(f"http://{ip}:{port}/users/{response.json()}") + end = time() + print(f"POST/GET 1st level dict (autokey): {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + s.put( + f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname", + json="TestValue!!", + ) + s.get(f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname") + end = time() + print(f"PUT/GET 1st level (value) dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + s.get(f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002") + end = time() + print(f"GET 2nd level dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + s.get( + f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + ) + end = time() + print(f"GET 2nd level (value) dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + s.put( + f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + json="TestValue!!", + ) + s.get( + f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + ) + end = time() + print(f"PUT/GET 2nd level (value) dict: {int(n_loop/(end-start))} Req/s") + + finally: + proc.terminate() + s.close() diff --git a/test/test_test_module.py b/test/test_test_module.py deleted file mode 100644 index 9f4280d..0000000 --- a/test/test_test_module.py +++ /dev/null @@ -1,35 +0,0 @@ -# pyrestresource (c) by chacha -# -# pyrestresource is licensed under a -# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. -# -# You should have received a copy of the license along with this -# work. If not, see . - -import unittest -from os import chdir -from io import StringIO -from contextlib import redirect_stdout,redirect_stderr -from pathlib import Path - -print(__name__) -print(__package__) - -from src import pyrestresource - -testdir_path = Path(__file__).parent.resolve() -chdir(testdir_path.parent.resolve()) - -class Testtest_module(unittest.TestCase): - def setUp(self) -> None: - chdir(testdir_path.parent.resolve()) - - def test_version(self): - self.assertNotEqual(pyrestresource.__version__,"?.?.?") - - def test_test_module(self): - - with redirect_stdout(StringIO()) as capted_stdout, redirect_stderr(StringIO()) as capted_stderr: - self.assertEqual(pyrestresource.test_function(41),42) - self.assertEqual(len(capted_stderr.getvalue()),0) - self.assertEqual(capted_stdout.getvalue().strip(),"Hello world !") \ No newline at end of file -- 2.47.3 From 6a554af8f8bf4ef3b2bd85e542a9d7153fb588bb Mon Sep 17 00:00:00 2001 From: cclecle Date: Wed, 1 Nov 2023 00:02:54 +0000 Subject: [PATCH 02/20] fix dockerfile --- .settings/org.eclipse.core.resources.prefs | 1 + Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.settings/org.eclipse.core.resources.prefs b/.settings/org.eclipse.core.resources.prefs index f0456ba..59b1ede 100644 --- a/.settings/org.eclipse.core.resources.prefs +++ b/.settings/org.eclipse.core.resources.prefs @@ -1,3 +1,4 @@ eclipse.preferences.version=1 encoding//src/pyrestresource/__init__.py=utf-8 +encoding//src/pyrestresource/__metadata__.py=utf-8 encoding/=UTF-8 diff --git a/Dockerfile b/Dockerfile index 73b5bb6..ee45774 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,12 +6,12 @@ # You should have received a copy of the license along with this # work. If not, see . -FROM debian:bullseye-slim +FROM debian:bookworm-slim ENV DEBIAN_FRONTEND=noninteractive RUN apt update -RUN apt install -y python3.9 python3-virtualenv python3-pip git python3.9-venv weasyprint +RUN apt install -y python3.11 python3-virtualenv python3-pip git python3-venv weasyprint RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --upgrade virtualenv -- 2.47.3 From 5ce727e60ca94d25572560420a3a894ce4a87884 Mon Sep 17 00:00:00 2001 From: cclecle Date: Wed, 1 Nov 2023 00:08:34 +0000 Subject: [PATCH 03/20] fix dockerfile for debian 12 --- Dockerfile | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index ee45774..9c2c266 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,8 +11,4 @@ FROM debian:bookworm-slim ENV DEBIAN_FRONTEND=noninteractive RUN apt update -RUN apt install -y python3.11 python3-virtualenv python3-pip git python3-venv weasyprint - -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --upgrade virtualenv -RUN python3 -m pip install --upgrade setuptools wheel build \ No newline at end of file +RUN apt install -y python3.11 python3-virtualenv python3-pip git python3-venv weasyprint \ No newline at end of file -- 2.47.3 From f9b016d8450c14221493d345f5bae2b74b0a63f7 Mon Sep 17 00:00:00 2001 From: cclecle Date: Wed, 1 Nov 2023 02:12:56 +0000 Subject: [PATCH 04/20] implement server-side Set-Cookie --- .settings/org.eclipse.core.resources.prefs | 1 + src/pyrestresource/rest_request.py | 12 --- src/pyrestresource/rest_resource.py | 81 +++++++++++++++--- src/pyrestresource/rest_resource_handler.py | 10 +-- src/pyrestresource/rest_resource_plugin.py | 18 +++- src/pyrestresource/rest_resource_walker.py | 7 +- test/test_rest_login.py | 91 ++++++++++++++++++++- test/test_rest_resource.py | 2 +- test/test_rest_resource_plugins.py | 12 +-- test/test_rest_webserver.py | 2 +- 10 files changed, 189 insertions(+), 47 deletions(-) diff --git a/.settings/org.eclipse.core.resources.prefs b/.settings/org.eclipse.core.resources.prefs index 59b1ede..f39f80c 100644 --- a/.settings/org.eclipse.core.resources.prefs +++ b/.settings/org.eclipse.core.resources.prefs @@ -1,4 +1,5 @@ eclipse.preferences.version=1 encoding//src/pyrestresource/__init__.py=utf-8 encoding//src/pyrestresource/__metadata__.py=utf-8 +encoding//src/pyrestresource/rest_resource.py=utf-8 encoding/=UTF-8 diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index 951456d..a24a80a 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -114,18 +114,6 @@ class RestRequest(Generic[_T_RestRequestParams]): self._saved_url_stack: list[str] self.url_stack_index: int - # detecting Optional fields in type_request_params (to extract real type) - # if False: # deprecated => Generic - # if get_origin(type_request_params) is Union: - # datatype = get_args(type_request_params) - # if len(datatype) == 2: - # if datatype[0] is type(None): - # type_request_params = datatype[1] - # elif datatype[1] is type(None): - # type_request_params = datatype[0] - # else: - # raise RuntimeError("Union is only allowed to describe Optional (e.g. Union[XXX,None])") - # = updating request from a previous one = if origin_request: self.__dict__ = origin_request.__dict__.copy() diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index 987d8d2..7390beb 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -95,6 +95,34 @@ class RestResourceWalker_Root__tree_exclude(RestResourceWalker_Root): ] +class pluginCTX: + cookies: dict[str, str] = dict() + + +class RestResourceWalker_Sub_RestResourceBase__init_pluginCTX(RestResourceWalker_Sub_RestResourceBase): + _pluginCTX: pluginCTX = pluginCTX() + + def process(self): + # import pprint + + # print(f"hey: {self.resource}") + # pprint.pprint(self.resource) + # print(type(self.resource)) + # print(self.annotation._plugins_) + + for plugin in self.annotation._plugins_.values(): + # print("SET COOKIE") + plugin.cookies = self._pluginCTX.cookies + + +class RestResourceWalker_Root__init_pluginCTX(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase__init_pluginCTX, + ] + + class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): def process(self) -> None: datatype = get_args(self.annotation) @@ -117,7 +145,7 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): and type(self.resource.json_schema_extra) is dict and "plugin" in self.resource.json_schema_extra ): - plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"]() + plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"] if not isinstance(plugin_dict, ResourcePlugin_dict): raise RuntimeError("Wrong plugin signature provided") self.parent.annotation._plugins_[self.resource_name] = plugin_dict @@ -141,7 +169,7 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi self.parent.annotation._primary_key_ = self.resource_name if "plugin" in self.resource.json_schema_extra and self.resource.json_schema_extra["plugin"]: - plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"]() + plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"] if not isinstance(plugin_field, ResourcePlugin_field): raise RuntimeError("Wrong plugin signature provided") self.parent.annotation._plugins_[self.resource_name] = plugin_field @@ -169,7 +197,7 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ and type(self.resource.json_schema_extra) is dict and "plugin" in self.resource.json_schema_extra ): - plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"]() + plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase): raise RuntimeError("Wrong plugin signature provided") self.parent.annotation._plugins_[self.resource_name] = plugin_resource @@ -190,6 +218,7 @@ def register_rest_rootpoint(klass: type[RestResourceBase]): class RestResourceBase(ABC, BaseModel, validate_assignment=True): + _resp_cookies: ClassVar[dict[str, str]] = dict() _dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {} _dict_value_type_: ClassVar[dict[str, T_T_DictValues]] = {} _model_dump_excluded_: ClassVar[dict[str, bool]] = {} @@ -226,8 +255,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if b"content-type" in scope["headers"]: assert scope["headers"][b"content-type"] == b"application/json" + # import pprint + + # print("----REC HEADER ---") + # pprint.pprint(scope["headers"]) + body = await self.read_body(receive) verb = rsrc_verb[scope["method"]] + + type(self)._resp_cookies = dict() + result = self.process_request( scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8") ) @@ -236,18 +273,25 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if verb in (rsrc_verb.POST, rsrc_verb.PUT): status = 201 - await send( - { - "type": "http.response.start", - "status": status, - "headers": [ - [b"content-type", b"application/json"], - ], - } - ) + header_resp = { + "type": "http.response.start", + "status": status, + "headers": [ + [b"content-type", b"application/json"], + ], + } + + for name, value in type(self)._resp_cookies.items(): + header_resp["headers"].append(["Set-Cookie", f"{name}={value}"]) + + # print("----SENT HEADER ---") + # pprint.pprint(header_resp) + await send(header_resp) + body = None if result: body = result.encode("utf-8") + await send( { "type": "http.response.body", @@ -256,7 +300,11 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): ) def process_request( - self, url: str, verb: rsrc_verb = rsrc_verb.GET, data_json: Optional[str] = None, query_string: Optional[str] = None + self, + url: str, + verb: rsrc_verb = rsrc_verb.GET, + data_json: Optional[str] = None, + query_string: Optional[str] = None, ) -> Optional[str]: from .rest_resource_handler import ( ResourceHandler, @@ -267,9 +315,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if data_json: data = json.loads(data_json) + RestResourceWalker_Sub_RestResourceBase__init_pluginCTX._pluginCTX.cookies = type(self)._resp_cookies + RestResourceWalker_Root__init_pluginCTX(self).process() + ressource: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) result = ressource.process_verb() + # print("OOO") + # print(type(self)._resp_cookies) + # print("OOO2") + if isinstance(result, RestResourceBase): exclude: Optional[dict[str, bool]] = None raw_exclude = RestResourceWalker_Root__tree_exclude(result).process() diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 56d304d..cc35b67 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -613,19 +613,19 @@ class ResourceHandler_simple( value = self.req.get_data() if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: - print("PLUGIN FOUND") + # print("PLUGIN FOUND") plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)], ) - print(value) + # print(value) value = plugin_simple.handle_field_put(value, params) - print(value) + # print(value) - print(self.req.get_resource_origin(1)) + # print(self.req.get_resource_origin(1)) setattr( self.prev_handler.resource, self.req.get_resource_origin(1), value, ) - print(self.prev_handler.resource) + # print(self.prev_handler.resource) diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py index 10a5570..1ad2fb2 100644 --- a/src/pyrestresource/rest_resource_plugin.py +++ b/src/pyrestresource/rest_resource_plugin.py @@ -25,8 +25,20 @@ if TYPE_CHECKING or True: ) +class ResourcePlugin(Protocol): + cookies: dict[str, str] = dict() + + def set_cookie(self, name: str, value: str): + # print("AAA") + # print(name) + # print(value) + # print(self.cookies) + # print(type(self.cookies)) + self.cookies[name] = value + + @runtime_checkable -class ResourcePlugin_field(Protocol[TV_SupportedRESTFields]): +class ResourcePlugin_field(ResourcePlugin, Protocol[TV_SupportedRESTFields]): @abstractmethod def handle_field_get(self, resource: TV_SupportedRESTFields, params: RestRequestParams_GET) -> TV_SupportedRESTFields: ... @@ -47,7 +59,7 @@ class ResourcePlugin_field_default(ResourcePlugin_field[TV_SupportedRESTFields]) @runtime_checkable -class ResourcePlugin_RestResourceBase(Protocol[TV_RestResourceBase]): +class ResourcePlugin_RestResourceBase(ResourcePlugin, Protocol[TV_RestResourceBase]): @abstractmethod def handle_resource_get( self, @@ -84,7 +96,7 @@ class ResourcePlugin_RestResourceBase_default(ResourcePlugin_RestResourceBase[TV @runtime_checkable -class ResourcePlugin_dict(Protocol[_T_DictKey, _T_DictValues]): +class ResourcePlugin_dict(ResourcePlugin, Protocol[_T_DictKey, _T_DictValues]): @abstractmethod def handle_dict_get_keys( self, diff --git a/src/pyrestresource/rest_resource_walker.py b/src/pyrestresource/rest_resource_walker.py index caa247c..d30d3c1 100644 --- a/src/pyrestresource/rest_resource_walker.py +++ b/src/pyrestresource/rest_resource_walker.py @@ -95,8 +95,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): self.subdatatype = get_args(self.annotation) - # self.info() - + """ def info(self) -> None: print(f"{type(self).__name__}->info()") print("==========================") @@ -119,6 +118,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): print(f"{id(_rsrc.annotation)}:{_rsrc.annotation}") _rsrc = _rsrc.parent print("-------------------") + """ @classmethod def init_sub(cls, walker: RestResourceWalker_Root) -> None: @@ -193,6 +193,8 @@ class RestResourceWalker_Sub_T_Dict(RestResourceWalker_Sub): return (_type_resource is dict), _anno, _optional def get_sub_resources(self) -> list[tuple[str, FieldInfo]]: + # print("????") + # print(self.subdatatype[1]) return [(self.resource_name, self.subdatatype[1])] def get_future(self) -> Optional[RestResourceWalkerFutureResult]: @@ -255,6 +257,7 @@ class RestResourceWalker_Root: ) if sub_walker_initial is not None: + sub_walker_initial.process() sub_walker_initial.get_future() resource_list: list[tuple[str, FieldInfo | Type["RestResourceBase"], RestResourceWalker_Sub]] = [ (subresource_name, subresource, sub_walker_initial) diff --git a/test/test_rest_login.py b/test/test_rest_login.py index 29114d1..87650f6 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -6,12 +6,21 @@ from pathlib import Path from typing import Optional, Annotated from pydantic import Field from uuid import UUID, uuid4 +from time import time, sleep from time import time import json +import uvicorn +import socket +import requests +from contextlib import closing +from multiprocessing import Process +from secrets import token_hex print(__name__) print(__package__) +from pydantic import BaseModel + from src.pyrestresource import ( register_rest_rootpoint, RestResourceBase, @@ -32,7 +41,14 @@ chdir(testdir_path.parent.resolve()) # to allow mock-ing, all the tested classes are in a function def init_classes(): + class UserLogin(BaseModel): + username: str + secret: str + token: Optional[str] = None + class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): + ar_UserLogin: list[UserLogin] = [UserLogin(username="chacha", secret="123456")] + def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login: print("hook GET") print(resource) @@ -41,8 +57,22 @@ def init_classes(): def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login: print("hook PUT") - print(resource) - print(params) + + print(resource.username) + print(resource.secret) + + for _UserLogin in self.ar_UserLogin: + if _UserLogin.username == resource.username and _UserLogin.secret == resource.secret: + print("user connected") + _UserLogin.token = token_hex(16) + self.set_cookie("test", _UserLogin.token) + print(f"generated token: {_UserLogin.token}") + return resource + print("login NOT found") + # print(resource) + # print(resource.username) + # print(resource.secret) + # print(params) return resource class Login(RestResourceBase): @@ -54,7 +84,7 @@ def init_classes(): class RootApp(RestResourceBase): login: Login = Field( default=Login(), - plugin=ResourcePlugin_Login, + plugin=ResourcePlugin_Login(), ) # this add the classes to globals to allow using them later on @@ -63,6 +93,21 @@ def init_classes(): globals()[RootApp.__name__] = RootApp +def find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + hostname = socket.gethostname() + IPAddr = socket.gethostbyname(hostname) + return "localhost", s.getsockname()[1] + + +def launch_server(ip, port): + print(f"port2={port}") + init_classes() + uvicorn.run(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True) + + class Test_RestAPI_LOGIN(unittest.TestCase): def setUp(self) -> None: chdir(testdir_path.parent.resolve()) @@ -73,8 +118,46 @@ class Test_RestAPI_LOGIN(unittest.TestCase): result = self.testapp.process_request("/login", rsrc_verb.GET) print(result) - result = self.testapp.process_request("/login", rsrc_verb.PUT, '{"username":"toto","secret":"123456"}') + result = self.testapp.process_request("/login", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') print(result) result = self.testapp.process_request("/login", rsrc_verb.GET) print(result) + + +class Test_RestAPI_LOGIN_Web(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + + def test_login(self): + ip, port = find_free_port() + print(f"ip1={ip}") + print(f"port1={port}") + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + try: + # Login in + + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "chacha", "secret": "123456"}, + ) + print(response) + print(response.headers) + self.assertEqual(response.status_code, 201) + + response = s.get(f"http://{ip}:{port}/login") + response = s.get(f"http://{ip}:{port}/login") + response = s.get(f"http://{ip}:{port}/login") + + finally: + proc.terminate() + s.close() diff --git a/test/test_rest_resource.py b/test/test_rest_resource.py index a54e0c8..5120150 100644 --- a/test/test_rest_resource.py +++ b/test/test_rest_resource.py @@ -472,7 +472,7 @@ class Test_RestAPI_PERFO(unittest.TestCase): init_classes() self.testapp = RootApp() - @unittest.skip + # @unittest.skip def test_perf_dict(self): print(f"LIB INTERNAL PERF TEST") n_loop = 10000 diff --git a/test/test_rest_resource_plugins.py b/test/test_rest_resource_plugins.py index ac40326..b10cdda 100644 --- a/test/test_rest_resource_plugins.py +++ b/test/test_rest_resource_plugins.py @@ -47,13 +47,13 @@ def init_classes(): class Info_get(RestResourceBase): # test plugin injection within annotation # + test plugin on a simple field - version: Annotated[str, Field(plugin=ResourcePlugin_version_get)] + version: Annotated[str, Field(plugin=ResourcePlugin_version_get())] api_version: str class Info_put(RestResourceBase): # test plugin injection within annotation # + test plugin on a simple field - version: Annotated[str, Field(plugin=ResourcePlugin_version_put)] + version: Annotated[str, Field(plugin=ResourcePlugin_version_put())] api_version: str @register_rest_rootpoint @@ -62,7 +62,7 @@ def init_classes(): # + test plugin on a RestResourceBase field info: Info_get = Field( default=Info_get(version="0.0.1", api_version="0.0.2"), - plugin=ResourcePlugin_Info, + plugin=ResourcePlugin_Info(), ) info_put: Info_put = Field( default=Info_put(version="0.0.1", api_version="0.0.2"), @@ -83,7 +83,7 @@ def init_bad_plugin1(): return resource class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] + tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource())] @register_rest_rootpoint class RootApp2(RestResourceBase): @@ -99,7 +99,7 @@ def init_bad_plugin2(): return resource class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] + tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource())] @register_rest_rootpoint class RootApp2(RestResourceBase): @@ -114,7 +114,7 @@ def init_bad_plugin3(): pass class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] + tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource())] @register_rest_rootpoint class RootApp2(RestResourceBase): diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py index 28edf77..76033a9 100644 --- a/test/test_rest_webserver.py +++ b/test/test_rest_webserver.py @@ -287,7 +287,7 @@ class Test_RestAPI_WebServer(unittest.TestCase): proc.terminate() s.close() - @unittest.skip + # @unittest.skip def test_perf_dict(self): print(f"SOCKET PERF TEST") n_loop = 10000 -- 2.47.3 From 2251b1d5e9b81eb08576f8dc3bdfe6452aaa3a59 Mon Sep 17 00:00:00 2001 From: cclecle Date: Thu, 2 Nov 2023 01:00:30 +0000 Subject: [PATCH 05/20] optimization (WIP) --- .settings/org.eclipse.core.resources.prefs | 1 + src/pyrestresource/rest_request.py | 46 ++++++------ src/pyrestresource/rest_resource.py | 67 ++++++++---------- src/pyrestresource/rest_resource_handler.py | 37 ++++++---- src/pyrestresource/rest_resource_plugin.py | 8 ++- src/pyrestresource/rest_resource_walker.py | 28 +++----- test/test_rest_login.py | 8 +-- test/test_rest_resource_plugins.py | 12 ++-- test/test_rest_resource_walker.py | 78 +++++++++------------ 9 files changed, 136 insertions(+), 149 deletions(-) diff --git a/.settings/org.eclipse.core.resources.prefs b/.settings/org.eclipse.core.resources.prefs index f39f80c..e658763 100644 --- a/.settings/org.eclipse.core.resources.prefs +++ b/.settings/org.eclipse.core.resources.prefs @@ -2,4 +2,5 @@ eclipse.preferences.version=1 encoding//src/pyrestresource/__init__.py=utf-8 encoding//src/pyrestresource/__metadata__.py=utf-8 encoding//src/pyrestresource/rest_resource.py=utf-8 +encoding//src/pyrestresource/rest_resource_handler_walker.py=utf-8 encoding/=UTF-8 diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index a24a80a..23ff34e 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -62,22 +62,24 @@ class RequestFactory( return RestRequest[RestRequestParams_DELETE](self.cls_RestRequestParams_DELETE, url, verb, data, query_string) raise RuntimeError("Invalid Verb") - def update_RestRequest(self, origin_request: RestRequest) -> RestRequest: + def update_RestRequest(self, request: RestRequest) -> None: """create an updated copy of a RestRequest object based on a different LUT_verb configuration Args: origin_request: the original request """ # /!\ mypy seems not being able to propagate typevar to composed classes - if origin_request.verb is rsrc_verb.GET: - return RestRequest[RestRequestParams_GET](self.cls_RestRequestParams_GET, None, None, None, None, origin_request) - if origin_request.verb is rsrc_verb.PUT: - return RestRequest[RestRequestParams_PUT](self.cls_RestRequestParams_PUT, None, None, None, None, origin_request) - if origin_request.verb is rsrc_verb.POST: - return RestRequest[RestRequestParams_POST](self.cls_RestRequestParams_POST, None, None, None, None, origin_request) - if origin_request.verb is rsrc_verb.DELETE: - return RestRequest[RestRequestParams_DELETE](self.cls_RestRequestParams_DELETE, None, None, None, None, origin_request) - raise RuntimeError("Invalid Verb") + if request.verb is rsrc_verb.GET: + request.update_ReqParams(self.cls_RestRequestParams_GET) + elif request.verb is rsrc_verb.PUT: + request.update_ReqParams(self.cls_RestRequestParams_PUT) + elif request.verb is rsrc_verb.POST: + request.update_ReqParams(self.cls_RestRequestParams_POST) + elif request.verb is rsrc_verb.DELETE: + request.update_ReqParams(self.cls_RestRequestParams_DELETE) + else: + raise RuntimeError("Invalid Verb") + return class RestRequest(Generic[_T_RestRequestParams]): @@ -87,21 +89,22 @@ class RestRequest(Generic[_T_RestRequestParams]): def __init__( self, type_request_params: type[_T_RestRequestParams], - url: Optional[str] = None, - verb: Optional[rsrc_verb] = None, + url: str, + verb: rsrc_verb, data: Optional[dict[str, T_SupportedRESTFields]] = None, query_string: Optional[str] = None, - origin_request: Optional[RestRequest] = None, + incoming_cookie: dict[str, str] = {}, + outgoing_cookie: dict[str, str] = {}, ) -> None: """class to handle a request context, that will be kept and updated while walking url parts Args: + type_request_params: type of the request param url: http url of the request verb: http verb received data: data associated with the request - type_request_params: type of the request param - origin_request: orginial request in case of updates. In this case, all other argument - but type_request_params - are ignored and inherited from the origin_request + query_string: query arguments after url (eg: ?arg1=value1&arg2=value2 ...) """ # defining all types @@ -113,14 +116,8 @@ class RestRequest(Generic[_T_RestRequestParams]): self.url_stack: list[str] self._saved_url_stack: list[str] self.url_stack_index: int - - # = updating request from a previous one = - if origin_request: - self.__dict__ = origin_request.__dict__.copy() - if type_request_params: - self.ReqParams = type_request_params(**self._saved_url_params) - # print("request updated") - return + self.incoming_cookie: dict[str, str] = incoming_cookie + self.outgoing_cookie: dict[str, str] = outgoing_cookie # = or create a fresh one = if url is None or verb is None or data is None: @@ -144,6 +141,9 @@ class RestRequest(Generic[_T_RestRequestParams]): self._saved_url_stack = self.url_stack.copy() self.url_stack_index = 0 + def update_ReqParams(self, type_request_params: type[_T_RestRequestParams]): + self.ReqParams = type_request_params(**self._saved_url_params) + def _parse_url(self, url: str) -> None: # remove repeated slash ('/') url = sub(r"\/{2,}", "/", url) diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index 7390beb..4ba860c 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -23,7 +23,7 @@ from typing import ( TYPE_CHECKING, ) import json -from pydantic.fields import FieldInfo +from pydantic.fields import FieldInfo, Field from pydantic import BaseModel from .helpers import _JSONEncoder @@ -43,6 +43,8 @@ from .rest_resource_walker import ( RestResourceWalker_Sub_RestResourceBase, ) +from .rest_request import RestRequest + if TYPE_CHECKING: from .rest_types import ( T_ListIndex, @@ -95,34 +97,6 @@ class RestResourceWalker_Root__tree_exclude(RestResourceWalker_Root): ] -class pluginCTX: - cookies: dict[str, str] = dict() - - -class RestResourceWalker_Sub_RestResourceBase__init_pluginCTX(RestResourceWalker_Sub_RestResourceBase): - _pluginCTX: pluginCTX = pluginCTX() - - def process(self): - # import pprint - - # print(f"hey: {self.resource}") - # pprint.pprint(self.resource) - # print(type(self.resource)) - # print(self.annotation._plugins_) - - for plugin in self.annotation._plugins_.values(): - # print("SET COOKIE") - plugin.cookies = self._pluginCTX.cookies - - -class RestResourceWalker_Root__init_pluginCTX(RestResourceWalker_Root): - cls_RestResourceWalker_Sub = [ - RestResourceWalker_Sub_T_Dict, - RestResourceWalker_Sub_RestFields, - RestResourceWalker_Sub_RestResourceBase__init_pluginCTX, - ] - - class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): def process(self) -> None: datatype = get_args(self.annotation) @@ -139,6 +113,9 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): self.parent.annotation._dict_value_type_[self.resource_name] = datatype[1] # pylint: disable=protected-access self.parent.annotation._model_dump_excluded_[self.resource_name] = True # pylint: disable=protected-access + self.resource.exclude = True + self.parent.resource.model_rebuild(force=True) + if ( isinstance(self.resource, FieldInfo) and self.resource.json_schema_extra is not None @@ -158,11 +135,19 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFields): def process(self) -> None: if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + import pprint + + # print("1aaaaaaaaaa") + # pprint.pprint(self.resource.json_schema_extra) + # pprint.pprint(self.annotation) + # pprint.pprint(self.resource.exclude) if ( isinstance(self.resource, FieldInfo) and self.resource.json_schema_extra is not None and type(self.resource.json_schema_extra) is dict ): + # print("aaaaaaaaaa") + if "primary_key" in self.resource.json_schema_extra and self.resource.json_schema_extra["primary_key"] is True: if self.parent.annotation._primary_key_ is not None: raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") @@ -190,6 +175,8 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ # preprocessing types / structure if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): self.parent.annotation._model_dump_excluded_[self.resource_name] = True + self.resource.exclude = True + self.parent.resource.model_rebuild(force=True) if ( isinstance(self.resource, FieldInfo) @@ -229,6 +216,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): ResourcePlugin_field | ResourcePlugin_RestResourceBase | ResourcePlugin_dict, ] ] = {} + _request: Optional[RestRequest] = None def update(self, **new_data): for field, value in new_data.items(): @@ -263,12 +251,13 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): body = await self.read_body(receive) verb = rsrc_verb[scope["method"]] - type(self)._resp_cookies = dict() - + self._request = None result = self.process_request( scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8") ) + assert self._request != None + status = 200 if verb in (rsrc_verb.POST, rsrc_verb.PUT): status = 201 @@ -281,7 +270,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): ], } - for name, value in type(self)._resp_cookies.items(): + for name, value in self._request.outgoing_cookie.items(): header_resp["headers"].append(["Set-Cookie", f"{name}={value}"]) # print("----SENT HEADER ---") @@ -315,10 +304,9 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if data_json: data = json.loads(data_json) - RestResourceWalker_Sub_RestResourceBase__init_pluginCTX._pluginCTX.cookies = type(self)._resp_cookies - RestResourceWalker_Root__init_pluginCTX(self).process() - ressource: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) + self._request = ressource.get_request() + result = ressource.process_verb() # print("OOO") @@ -326,10 +314,11 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): # print("OOO2") if isinstance(result, RestResourceBase): - exclude: Optional[dict[str, bool]] = None - raw_exclude = RestResourceWalker_Root__tree_exclude(result).process() - exclude = next(iter(raw_exclude.values())) - return json.dumps(result.model_dump(mode="json", exclude=exclude)) + # exclude: Optional[dict[str, bool]] = None + # raw_exclude = RestResourceWalker_Root__tree_exclude(result).process() + # exclude = next(iter(raw_exclude.values())) + # return json.dumps(result.model_dump(mode="json", exclude=exclude)) + return json.dumps(result.model_dump(mode="json")) if result is not None: return json.dumps(result, cls=_JSONEncoder) diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index cc35b67..218f412 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -95,7 +95,9 @@ class ResourceHandler( self.req: RestRequest if prev_handler is not None: self.prev_handler = prev_handler - self.req = self._request_factory.update_RestRequest(self.prev_handler.req) + self.req = prev_handler.get_request() + self._request_factory.update_RestRequest(self.req) + elif None in [url, verb]: raise RuntimeError("if req not set, url,verb must be setted") else: @@ -129,6 +131,9 @@ class ResourceHandler( cls._ar_resource_handler_cls_.append(other_cls) return other_cls + def get_request(self) -> RestRequest: + return self.req + def process_verb( self, ) -> Optional[_T_Resource | T_DictKey | list[T_DictKey]]: @@ -455,10 +460,14 @@ class ResourceHandler_RestResourceBase( if self.req.get_resource_origin(0) == "/": return - if ( - self.req.get_resource_origin(0) not in self.resource.model_fields - or self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True - ): + # print("======") + # print(self.req.get_resource_origin(0)) + # print(len(self.req.get_url_stack())) + # print(self.resource._model_dump_excluded_) + # print(type(self.resource)) + # print(self.resource.exclude) + + if self.req.get_resource_origin(0) not in self.resource.model_fields: raise RuntimeError(f"Unknown or not allowed field access detected: {self.req.get_url_stack()}") def _handle_process_get(self, params) -> RestResourceBase: @@ -471,11 +480,11 @@ class ResourceHandler_RestResourceBase( for key, attr in self.resource.model_fields.items(): if key in self.resource._plugins_: if isinstance(self.resource._plugins_[key], ResourcePlugin_field): - plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, self.resource._plugins_[key]) + plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, self.resource._plugins_[key](self.req)) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_field_get(value, params)) elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): - plugin_field: ResourcePlugin_field = cast(ResourcePlugin_RestResourceBase, self.resource._plugins_[key]) + plugin_field: ResourcePlugin_field = cast(ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req)) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_resource_get(value, params)) @@ -496,14 +505,14 @@ class ResourceHandler_RestResourceBase( if isinstance(self.resource._plugins_[key], ResourcePlugin_field): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.resource._plugins_[key], + self.resource._plugins_[key](self.req), ) value = plugin_rsrc.handle_field_get(value, params) elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.resource._plugins_[key], + self.resource._plugins_[key](self.req), ) value = plugin_rsrc.handle_resource_get(value, params) @@ -523,7 +532,7 @@ class ResourceHandler_RestResourceBase( for key, attr in _new_resrc.model_fields.items(): if key in _new_resrc._plugins_: if isinstance(_new_resrc._plugins_[key], ResourcePlugin_field): - plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, _new_resrc._plugins_[key]) + plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, _new_resrc._plugins_[key](self.req)) value = getattr(_new_resrc, key) setattr(_new_resrc, key, plugin_field.handle_field_put(value, params)) @@ -539,7 +548,7 @@ class ResourceHandler_RestResourceBase( if key in self.prev_handler.prev_handler.resource._plugins_: plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.prev_handler.prev_handler.resource._plugins_[key], + self.prev_handler.prev_handler.resource._plugins_[key](self.req), ) _new_resrc = plugin_rsrc.handle_dict_elem_put(_new_resrc, params) # element is within a RestResourceBase @@ -548,7 +557,7 @@ class ResourceHandler_RestResourceBase( if key in self.prev_handler.resource._plugins_: plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.prev_handler.resource._plugins_[key], + self.prev_handler.resource._plugins_[key](self.req), ) _new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params) @@ -597,7 +606,7 @@ class ResourceHandler_simple( if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, - self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)], + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req), ) return plugin_simple.handle_field_get(self.resource, params) @@ -616,7 +625,7 @@ class ResourceHandler_simple( # print("PLUGIN FOUND") plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, - self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)], + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req), ) # print(value) value = plugin_simple.handle_field_put(value, params) diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py index 1ad2fb2..9308aa6 100644 --- a/src/pyrestresource/rest_resource_plugin.py +++ b/src/pyrestresource/rest_resource_plugin.py @@ -10,6 +10,7 @@ from .rest_types import ( TV_RestResourceBase, ) +from .rest_request import RestRequest if TYPE_CHECKING or True: from .rest_request_opt import ( @@ -26,15 +27,16 @@ if TYPE_CHECKING or True: class ResourcePlugin(Protocol): - cookies: dict[str, str] = dict() + def __init__(self, request: RestRequest) -> None: + self.request: RestRequest = request - def set_cookie(self, name: str, value: str): + def set_resp_cookie(self, name: str, value: str): # print("AAA") # print(name) # print(value) # print(self.cookies) # print(type(self.cookies)) - self.cookies[name] = value + self.request.outgoing_cookie[name] = value @runtime_checkable diff --git a/src/pyrestresource/rest_resource_walker.py b/src/pyrestresource/rest_resource_walker.py index d30d3c1..808eb5d 100644 --- a/src/pyrestresource/rest_resource_walker.py +++ b/src/pyrestresource/rest_resource_walker.py @@ -55,12 +55,13 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): resource_name: str, resource: FieldInfo | Type["RestResourceBase"], parent: Optional[RestResourceWalker_Sub] = None, + argument: Optional[any] = None, ) -> Optional[RestResourceWalker_Sub]: for sub in subs: _is_valid, _anno, _optional = sub.check_type(resource) if _is_valid is True: - return sub(resource_name, resource, parent, _anno, _optional) + return sub(resource_name, resource, parent, _anno, _optional, argument) raise RuntimeError(f"Incompatible Field Found: {type(resource).__name__}") return None @@ -70,8 +71,10 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): resource: FieldInfo | Type["RestResourceBase"], parent: Optional[RestResourceWalker_Sub] = None, annotation: Optional[type["RestResourceBase"]] = None, - optional: Optional[bool] = None, + _optional: Optional[bool] = None, + argument: Optional[any] = None, ): + self.argument: any = argument self.resource_name: str = resource_name self.resource: FieldInfo | Type["RestResourceBase"] = resource self.parent: Optional[RestResourceWalker_Sub] = parent @@ -84,11 +87,11 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): self.annotation: type["RestResourceBase"] self.optional: bool - if annotation is None or optional is None: + if annotation is None or _optional is None: self.annotation, self.optional = self.ProcessAnnotation(resource) else: self.annotation = annotation - self.optional = optional + self.optional = _optional if self.annotation is None: raise RuntimeError("Only annotated types are allowed in RestResourceBase derived classes") @@ -120,10 +123,6 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): print("-------------------") """ - @classmethod - def init_sub(cls, walker: RestResourceWalker_Root) -> None: - pass - @abstractmethod def get_future(self) -> Optional[RestResourceWalkerFutureResult]: return self.future_result @@ -238,6 +237,7 @@ class RestResourceWalker_Root: ] def __init__(self, resource: "RestResourceBase" | Type["RestResourceBase"]) -> None: + self.subwalker_argument: any = None from .rest_resource import RestResourceBase self.resource: Type["RestResourceBase"] @@ -246,14 +246,11 @@ class RestResourceWalker_Root: else: self.resource = resource - def process(self, deep_limit: Optional[int] = None) -> Optional[TV_RestResourceWalkerFutureResult]: + def process(self, argument: Optional[any] = None, deep_limit: Optional[int] = None) -> Optional[TV_RestResourceWalkerFutureResult]: current_deep: int = 0 - for cls_Sub in self.cls_RestResourceWalker_Sub: - _self = self - cls_Sub.init_sub(_self) sub_walker_initial: Optional[RestResourceWalker_Sub] = RestResourceWalker_Sub.get( - self.cls_RestResourceWalker_Sub, "/", self.resource, None + self.cls_RestResourceWalker_Sub, "/", self.resource, None, argument ) if sub_walker_initial is not None: @@ -270,10 +267,7 @@ class RestResourceWalker_Root: new_resource_list = [] for resource_name, resource, parent_sub_walker in resource_list: sub_walker = RestResourceWalker_Sub.get( - self.cls_RestResourceWalker_Sub, - resource_name, - resource, - parent_sub_walker, + self.cls_RestResourceWalker_Sub, resource_name, resource, parent_sub_walker, argument ) if sub_walker is not None: sub_walker.process() diff --git a/test/test_rest_login.py b/test/test_rest_login.py index 87650f6..2fdf8f1 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -65,7 +65,7 @@ def init_classes(): if _UserLogin.username == resource.username and _UserLogin.secret == resource.secret: print("user connected") _UserLogin.token = token_hex(16) - self.set_cookie("test", _UserLogin.token) + self.set_resp_cookie("test", _UserLogin.token) print(f"generated token: {_UserLogin.token}") return resource print("login NOT found") @@ -84,7 +84,7 @@ def init_classes(): class RootApp(RestResourceBase): login: Login = Field( default=Login(), - plugin=ResourcePlugin_Login(), + plugin=ResourcePlugin_Login, ) # this add the classes to globals to allow using them later on @@ -155,8 +155,8 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): self.assertEqual(response.status_code, 201) response = s.get(f"http://{ip}:{port}/login") - response = s.get(f"http://{ip}:{port}/login") - response = s.get(f"http://{ip}:{port}/login") + + response = s.get(f"http://{ip}:{port}/") finally: proc.terminate() diff --git a/test/test_rest_resource_plugins.py b/test/test_rest_resource_plugins.py index b10cdda..ac40326 100644 --- a/test/test_rest_resource_plugins.py +++ b/test/test_rest_resource_plugins.py @@ -47,13 +47,13 @@ def init_classes(): class Info_get(RestResourceBase): # test plugin injection within annotation # + test plugin on a simple field - version: Annotated[str, Field(plugin=ResourcePlugin_version_get())] + version: Annotated[str, Field(plugin=ResourcePlugin_version_get)] api_version: str class Info_put(RestResourceBase): # test plugin injection within annotation # + test plugin on a simple field - version: Annotated[str, Field(plugin=ResourcePlugin_version_put())] + version: Annotated[str, Field(plugin=ResourcePlugin_version_put)] api_version: str @register_rest_rootpoint @@ -62,7 +62,7 @@ def init_classes(): # + test plugin on a RestResourceBase field info: Info_get = Field( default=Info_get(version="0.0.1", api_version="0.0.2"), - plugin=ResourcePlugin_Info(), + plugin=ResourcePlugin_Info, ) info_put: Info_put = Field( default=Info_put(version="0.0.1", api_version="0.0.2"), @@ -83,7 +83,7 @@ def init_bad_plugin1(): return resource class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource())] + tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] @register_rest_rootpoint class RootApp2(RestResourceBase): @@ -99,7 +99,7 @@ def init_bad_plugin2(): return resource class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource())] + tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] @register_rest_rootpoint class RootApp2(RestResourceBase): @@ -114,7 +114,7 @@ def init_bad_plugin3(): pass class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource())] + tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] @register_rest_rootpoint class RootApp2(RestResourceBase): diff --git a/test/test_rest_resource_walker.py b/test/test_rest_resource_walker.py index a467e37..d562dc5 100644 --- a/test/test_rest_resource_walker.py +++ b/test/test_rest_resource_walker.py @@ -1,7 +1,7 @@ from __future__ import annotations import unittest -from typing import Annotated, Optional +from typing import Optional, cast from os import chdir from pathlib import Path @@ -28,50 +28,39 @@ chdir(testdir_path.parent.resolve()) class RestResourceWalker_Sub_T_Dict_TEST_Print(RestResourceWalker_Sub_T_Dict): - counter: dict[str, int] = {} - - @classmethod - def init_sub(cls, walker: RestResourceWalker_Root) -> None: - cls.counter = {} + cls_counter: dict[str, int] = {} def process(self) -> None: - if self.resource_name not in self.counter: - self.counter[self.resource_name] = 0 - self.counter[self.resource_name] = self.counter[self.resource_name] + 1 + counter = self.cls_counter + if self.resource_name not in counter: + counter[self.resource_name] = 0 + counter[self.resource_name] = counter[self.resource_name] + 1 - print(f"DICT {self.resource_name} {self.counter[self.resource_name]}") + print(f"DICT {self.resource_name} {counter[self.resource_name]}") class RestResourceWalker_Sub_RestFields_TEST_Print(RestResourceWalker_Sub_RestFields): - counter: dict[str, int] = {} - - @classmethod - def init_sub(cls, walker: RestResourceWalker_Root) -> None: - cls.counter = {} + cls_counter: dict[str, int] = {} def process(self) -> None: - if self.resource_name not in self.counter: - self.counter[self.resource_name] = 0 - self.counter[self.resource_name] = self.counter[self.resource_name] + 1 + counter = self.cls_counter + if self.resource_name not in counter: + counter[self.resource_name] = 0 + counter[self.resource_name] = counter[self.resource_name] + 1 - print(f"FIELD {self.resource_name} {self.counter[self.resource_name]}") + print(f"FIELD {self.resource_name} {counter[self.resource_name]}") -class RestResourceWalker_Sub_RestResourceBase_TEST_Print( - RestResourceWalker_Sub_RestResourceBase -): - counter: dict[str, int] = {} - - @classmethod - def init_sub(cls, walker: RestResourceWalker_Root) -> None: - cls.counter = {} +class RestResourceWalker_Sub_RestResourceBase_TEST_Print(RestResourceWalker_Sub_RestResourceBase): + cls_counter: dict[str, int] = {} def process(self) -> None: - if self.resource_name not in self.counter: - self.counter[self.resource_name] = 0 - self.counter[self.resource_name] = self.counter[self.resource_name] + 1 + counter = self.cls_counter + if self.resource_name not in counter: + counter[self.resource_name] = 0 + counter[self.resource_name] = counter[self.resource_name] + 1 - print(f"RestResource {self.resource_name} {self.counter[self.resource_name]}") + print(f"RestResource {self.resource_name} {counter[self.resource_name]}") class RestResourceWalker_Root_TEST_Print(RestResourceWalker_Root): @@ -114,11 +103,12 @@ class Test_Walker(unittest.TestCase): init_classes() def test_walk_class(self): + RestResourceWalker_Sub_T_Dict_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestFields_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestResourceBase_TEST_Print.cls_counter = {} test = RestResourceWalker_Root_TEST_Print(RootApp) - with redirect_stdout(StringIO()) as capted_stdout, redirect_stderr( - StringIO() - ) as capted_stderr: - test.process() + with redirect_stdout(StringIO()) as capted_stdout: + test.process({}) self.assertIn("RestResource info 1", capted_stdout.getvalue()) self.assertIn("RestResource info2 1", capted_stdout.getvalue()) self.assertIn("DICT peoples 1", capted_stdout.getvalue()) @@ -133,12 +123,13 @@ class Test_Walker(unittest.TestCase): self.assertIn("FIELD last_name 1", capted_stdout.getvalue()) def test_walk_obj(self): + RestResourceWalker_Sub_T_Dict_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestFields_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestResourceBase_TEST_Print.cls_counter = {} instRootApp = RootApp() test = RestResourceWalker_Root_TEST_Print(instRootApp) - with redirect_stdout(StringIO()) as capted_stdout, redirect_stderr( - StringIO() - ) as capted_stderr: - test.process() + with redirect_stdout(StringIO()) as capted_stdout: + test.process({}) self.assertIn("RestResource info 1", capted_stdout.getvalue()) self.assertIn("RestResource info2 1", capted_stdout.getvalue()) self.assertIn("DICT peoples 1", capted_stdout.getvalue()) @@ -153,11 +144,12 @@ class Test_Walker(unittest.TestCase): self.assertIn("FIELD last_name 1", capted_stdout.getvalue()) def test_walk_obj_nested_RestResource(self): + RestResourceWalker_Sub_T_Dict_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestFields_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestResourceBase_TEST_Print.cls_counter = {} instRootApp = RootApp() test = RestResourceWalker_Root_TEST_Print(instRootApp.info) - with redirect_stdout(StringIO()) as capted_stdout, redirect_stderr( - StringIO() - ) as capted_stderr: - test.process() + with redirect_stdout(StringIO()) as capted_stdout: + test.process({}) self.assertIn("FIELD version 1", capted_stdout.getvalue()) self.assertIn("FIELD api_version 1", capted_stdout.getvalue()) -- 2.47.3 From c3ff00e877ab69e62e8d18a35ed949f170fa2c6b Mon Sep 17 00:00:00 2001 From: cclecle Date: Fri, 3 Nov 2023 13:11:38 +0000 Subject: [PATCH 06/20] implement first ACL version (lack dict support) --- src/pyrestresource/__init__.py | 1 + src/pyrestresource/rest_ACL.py | 41 ++++++ src/pyrestresource/rest_request.py | 17 +++ src/pyrestresource/rest_resource.py | 146 +++++++++++++++----- src/pyrestresource/rest_resource_handler.py | 44 ++++-- src/pyrestresource/rest_types.py | 15 +- test/test_rest_login.py | 47 +++++-- 7 files changed, 250 insertions(+), 61 deletions(-) create mode 100644 src/pyrestresource/rest_ACL.py diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index 6ee712e..15d39df 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -53,3 +53,4 @@ from .rest_resource_plugin import ( ResourcePlugin_RestResourceBase_default, ResourcePlugin_dict_default, ) +from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule diff --git a/src/pyrestresource/rest_ACL.py b/src/pyrestresource/rest_ACL.py new file mode 100644 index 0000000..a75fa8b --- /dev/null +++ b/src/pyrestresource/rest_ACL.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from pydantic import BaseModel +from enum import Enum, auto + +from .rest_types import rsrc_verb + + +class ACL_target(BaseModel): + pass + + +class ACL_target_user(ACL_target): + name: str + + +class ACL_target_user_Annonymous(ACL_target): + name: str = "__ANNONYMOUS__" + + +class ACL_target_group(ACL_target): + name: str + + +class ACL_target_group_Annonymous(ACL_target): + name: str = "__ANNONYMOUS__" + + +class ACL_target_group_Any(ACL_target_group): + name: str = "__ANY__" + + +class ACL_rule(Enum): + ALLOW = auto() + DENY = auto() + + +class ACL_record(BaseModel): + verbs: list[rsrc_verb] + target: ACL_target + rule: ACL_rule diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index 23ff34e..80f223d 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -25,6 +25,8 @@ from .rest_request_opt import ( _T_RestRequestParams_PUT, ) +from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group, ACL_target_group_Annonymous + class RequestFactory( Generic[ @@ -118,6 +120,9 @@ class RestRequest(Generic[_T_RestRequestParams]): self.url_stack_index: int self.incoming_cookie: dict[str, str] = incoming_cookie self.outgoing_cookie: dict[str, str] = outgoing_cookie + self.user: ACL_target_user = ACL_target_user_Annonymous() + self.group: ACL_target_group = ACL_target_group_Annonymous() + self.result: Optional[str] = None # = or create a fresh one = if url is None or verb is None or data is None: @@ -141,6 +146,18 @@ class RestRequest(Generic[_T_RestRequestParams]): self._saved_url_stack = self.url_stack.copy() self.url_stack_index = 0 + def set_result(self, result: str): + self.result = result + + def get_result(self) -> Optional[str]: + return self.result + + def set_user(self, user: ACL_target_user): + self.user: ACL_target_user = user + + def set_group(self, group: ACL_target_group): + self.group: ACL_target_group = group + def update_ReqParams(self, type_request_params: type[_T_RestRequestParams]): self.ReqParams = type_request_params(**self._saved_url_params) diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index 4ba860c..8a70a0e 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -34,6 +34,16 @@ from .rest_resource_plugin import ( ResourcePlugin_dict, ) +from .rest_ACL import ( + ACL_record, + ACL_target_user, + ACL_target_group, + ACL_target_user_Annonymous, + ACL_target_group_Annonymous, + ACL_target_group_Any, + ACL_rule, +) + from .rest_resource_walker import ( RestResourceWalkerFutureResult, @@ -116,17 +126,26 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): self.resource.exclude = True self.parent.resource.model_rebuild(force=True) + self.parent.annotation._ACL_record_[self.resource_name] = [] + if ( isinstance(self.resource, FieldInfo) and self.resource.json_schema_extra is not None and type(self.resource.json_schema_extra) is dict - and "plugin" in self.resource.json_schema_extra ): - plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"] - if not isinstance(plugin_dict, ResourcePlugin_dict): - raise RuntimeError("Wrong plugin signature provided") - self.parent.annotation._plugins_[self.resource_name] = plugin_dict - # print("ADD DICT PLUGIN") + if "plugin" in self.resource.json_schema_extra: + plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"] + if not isinstance(plugin_dict, ResourcePlugin_dict): + raise RuntimeError("Wrong plugin signature provided") + self.parent.annotation._plugins_[self.resource_name] = plugin_dict + # print("ADD DICT PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") else: raise RuntimeError("dict must be contained in a RestResourceBase") @@ -141,6 +160,9 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi # pprint.pprint(self.resource.json_schema_extra) # pprint.pprint(self.annotation) # pprint.pprint(self.resource.exclude) + + self.parent.annotation._ACL_record_[self.resource_name] = [] + if ( isinstance(self.resource, FieldInfo) and self.resource.json_schema_extra is not None @@ -153,13 +175,20 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") self.parent.annotation._primary_key_ = self.resource_name - if "plugin" in self.resource.json_schema_extra and self.resource.json_schema_extra["plugin"]: + if "plugin" in self.resource.json_schema_extra: plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"] if not isinstance(plugin_field, ResourcePlugin_field): raise RuntimeError("Wrong plugin signature provided") self.parent.annotation._plugins_[self.resource_name] = plugin_field # print("ADD FIELD PLUGIN") + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") + else: raise RuntimeError("fields must be contained in a RestResourceBase") @@ -171,24 +200,33 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ setattr(self.annotation, "_model_dump_excluded_", {}) setattr(self.annotation, "_primary_key_", None) setattr(self.annotation, "_plugins_", {}) + setattr(self.annotation, "_ACL_record_", {}) # preprocessing types / structure if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): self.parent.annotation._model_dump_excluded_[self.resource_name] = True self.resource.exclude = True self.parent.resource.model_rebuild(force=True) + self.parent.annotation._ACL_record_[self.resource_name] = [] if ( isinstance(self.resource, FieldInfo) and self.resource.json_schema_extra is not None and type(self.resource.json_schema_extra) is dict - and "plugin" in self.resource.json_schema_extra ): - plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] - if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase): - raise RuntimeError("Wrong plugin signature provided") - self.parent.annotation._plugins_[self.resource_name] = plugin_resource - # print("ADD RESOURCE PLUGIN") + if "plugin" in self.resource.json_schema_extra: + plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] + if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase): + raise RuntimeError("Wrong plugin signature provided") + self.parent.annotation._plugins_[self.resource_name] = plugin_resource + # print("ADD RESOURCE PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") class RestResourceWalker_Root__tree_init(RestResourceWalker_Root): @@ -213,10 +251,55 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): _plugins_: ClassVar[ dict[ str, - ResourcePlugin_field | ResourcePlugin_RestResourceBase | ResourcePlugin_dict, + list[ACL_record], ] ] = {} - _request: Optional[RestRequest] = None + _ACL_record_: ClassVar[ + dict[ + str, + ACL_record, + ] + ] = {} + + def _check_acl(self, user: ACL_target_user, group: ACL_target_group, verb: rsrc_verb, field: str): + print(f"evaluate self ACLs rule: {self._ACL_record_}") + if verb is rsrc_verb.GET and self.model_fields[field].exclude is True: + print("ALLOWED (excluded field)") + return + for acl in self._ACL_record_[field]: + print(f"evaluate ACL rule: {acl}") + if verb in acl.verbs: + if isinstance(acl.target, ACL_target_user): + if user == acl.target: + if acl.rule is ACL_rule.ALLOW: + print("ALLOWED (user)") + return + raise RuntimeError(f"Not allowed access detected: {field}") + elif isinstance(acl.target, ACL_target_group): + if group == acl.target or acl.target == ACL_target_group_Any(): + if acl.rule is ACL_rule.ALLOW: + print("ALLOWED (group)") + return + raise RuntimeError(f"Not allowed access detected: {field}") + else: + raise RuntimeError(f"Wrong ACL target type: {field}") + print("ALLOWED (Default)") + + def check_acl_access(self, request: RestRequest) -> None: + """Check ACL on requested field access""" + self._check_acl(request.user, request.group, request.get_verb(), request.get_resource_origin(0)) + + def check_acl_operation(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None: + """Check ACL on requested field operation (involving checking sub-fields)""" + if request.get_verb() is rsrc_verb.GET: + for key in self.model_fields.keys(): + self._check_acl(request.user, request.group, rsrc_verb.GET, key) + elif request.get_verb() is rsrc_verb.PUT: + for key in new_data.keys(): + if key in self.model_fields: + self._check_acl(request.user, request.group, rsrc_verb.PUT, key) + else: + raise RuntimeError("Incompatible verb") def update(self, **new_data): for field, value in new_data.items(): @@ -251,12 +334,11 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): body = await self.read_body(receive) verb = rsrc_verb[scope["method"]] - self._request = None - result = self.process_request( + request: RestRequest = self.process_request( scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8") ) - assert self._request != None + assert request != None status = 200 if verb in (rsrc_verb.POST, rsrc_verb.PUT): @@ -270,7 +352,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): ], } - for name, value in self._request.outgoing_cookie.items(): + for name, value in request.outgoing_cookie.items(): header_resp["headers"].append(["Set-Cookie", f"{name}={value}"]) # print("----SENT HEADER ---") @@ -278,8 +360,8 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): await send(header_resp) body = None - if result: - body = result.encode("utf-8") + if request.get_result(): + body = request.get_result().encode("utf-8") await send( { @@ -294,7 +376,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): verb: rsrc_verb = rsrc_verb.GET, data_json: Optional[str] = None, query_string: Optional[str] = None, - ) -> Optional[str]: + ) -> RestRequest: from .rest_resource_handler import ( ResourceHandler, ResourceHandler_RestResourceBase, @@ -304,22 +386,20 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if data_json: data = json.loads(data_json) - ressource: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) - self._request = ressource.get_request() + ressource_handler: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) - result = ressource.process_verb() + request: RestRequest = ressource_handler.get_request() + assert request != None + + result = ressource_handler.process_verb() # print("OOO") # print(type(self)._resp_cookies) # print("OOO2") if isinstance(result, RestResourceBase): - # exclude: Optional[dict[str, bool]] = None - # raw_exclude = RestResourceWalker_Root__tree_exclude(result).process() - # exclude = next(iter(raw_exclude.values())) - # return json.dumps(result.model_dump(mode="json", exclude=exclude)) - return json.dumps(result.model_dump(mode="json")) + request.set_result(json.dumps(result.model_dump(mode="json"))) + elif result is not None: + request.set_result(json.dumps(result, cls=_JSONEncoder)) - if result is not None: - return json.dumps(result, cls=_JSONEncoder) - return None + return request diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 218f412..5996cc6 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -19,6 +19,15 @@ from .rest_resource_plugin import ( ResourcePlugin_RestResourceBase, ) +from .rest_ACL import ( + ACL_target_user, + ACL_target_group, + ACL_target_user_Annonymous, + ACL_target_group_Annonymous, + ACL_target_group_Any, + ACL_rule, +) + from .rest_request_opt import ( RestRequestParams_POST, RestRequestParams_DELETE, @@ -174,7 +183,7 @@ class ResourceHandler( # reveal_type(next_resource) _next_resource = cast(_T_Resource, next_resource) # reveal_type(_next_resource) - # print(f"[DEBUG] next_resource = {type(next_resource).__name__}") + print(f"[DEBUG] next_resource = {type(next_resource).__name__}") if ( isinstance(_next_resource, RestResourceBase) @@ -194,7 +203,7 @@ class ResourceHandler( self.next_handler = next_resource_handler return next_resource_handler - # in the context of _find_resource, only resource real values can be retrieved + # in _find_resource context, only resource's real values can be retrieved raise RuntimeError("Wrong request") def _check_access_rights(self): @@ -455,28 +464,43 @@ class ResourceHandler_RestResourceBase( def _check_access_rights(self) -> None: super()._check_access_rights() - # print(f"{type(self).__name__}->_check_access_rights()") + print(f"{type(self).__name__}->_check_access_rights()") if self.req.get_resource_origin(0) == "/": return - # print("======") - # print(self.req.get_resource_origin(0)) + print("==================") + print(self.req.get_resource_origin(0)) # print(len(self.req.get_url_stack())) # print(self.resource._model_dump_excluded_) # print(type(self.resource)) # print(self.resource.exclude) if self.req.get_resource_origin(0) not in self.resource.model_fields: - raise RuntimeError(f"Unknown or not allowed field access detected: {self.req.get_url_stack()}") + raise RuntimeError(f"Unknown field access detected: {self.req.get_url_stack()}") + + self.resource.check_acl_access(self.req) + + if len(self.req.get_url_stack()) == 0: # destination reached + if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET: + raise RuntimeError(f"Not allowed READ access detected: {self.req.get_url_stack()}") + """ # not sure init_var has the expected behavior (read_only) + if self.resource.model_fields[self.req.get_resource_origin(0)].init_var is True and self.req.get_verb() in [ + rsrc_verb.POST, + rsrc_verb.PUT, + rsrc_verb.DELETE, + ]: + raise RuntimeError(f"Not allowed WRITE access detected: {self.req.get_url_stack()}") + """ def _handle_process_get(self, params) -> RestResourceBase: # print(f"{type(self).__name__}->_process_get()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") - # CASE 1: no more item in url_stack => we reached the endpoint + # CASE 1: no more item in url_stack => we reached the endpoint (operation) # So we are in a RestResourceBase instance and must return the content if len(self.req.get_url_stack()) == 0: + self.resource.check_acl_operation(self.req) for key, attr in self.resource.model_fields.items(): if key in self.resource._plugins_: if isinstance(self.resource._plugins_[key], ResourcePlugin_field): @@ -492,12 +516,12 @@ class ResourceHandler_RestResourceBase( # print(result) return self.resource - # CASE 2: specific case for root Node + # CASE 2: specific (operation) case for root Node # TODO: this must probably be merged with the previous bloc if self.req.get_resource_origin(0) == "/": return self.resource - # CASE 3: in between + # CASE 3: in between (access) value = getattr(self.resource, self.req.get_resource_origin(0)) key = self.req.get_resource_origin(0) @@ -522,6 +546,8 @@ class ResourceHandler_RestResourceBase( # print(f"{type(self).__name__}->_process_put()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + self.resource.check_acl_operation(self.req, self.req.get_data()) + # creating a copy of the current resource _new_resrc = self.resource.copy() # updating values based on nex data diff --git a/src/pyrestresource/rest_types.py b/src/pyrestresource/rest_types.py index 17a95bd..55e7c12 100644 --- a/src/pyrestresource/rest_types.py +++ b/src/pyrestresource/rest_types.py @@ -39,10 +39,9 @@ _T_SupportedRESTFields = [ Path, IPv4Address, IPv4Network, + NoneType, ] -T_SupportedRESTFields = Union[ - UUID, str, int, float, bool, bytes, datetime, Path, IPv4Address, IPv4Network -] +T_SupportedRESTFields = Union[UUID, str, int, float, bool, bytes, datetime, Path, IPv4Address, IPv4Network, NoneType] TV_SupportedRESTFields = TypeVar( "TV_SupportedRESTFields", UUID, @@ -55,6 +54,7 @@ TV_SupportedRESTFields = TypeVar( Path, IPv4Address, IPv4Network, + NoneType, ) if get_origin(T_SupportedRESTFields) is not Union: @@ -68,12 +68,8 @@ T_FieldValue = Union[T_SupportedRESTFields, "RestResourceBase"] T_ListIndex = NewType("T_ListIndex", int) T_ListSize = NewType("T_ListSize", int) -T_DictKey = Union[ - UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network -] # datetime is removed because non-hashable -_T_DictKey = TypeVar( - "_T_DictKey", UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network -) +T_DictKey = Union[UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network] # datetime is removed because non-hashable +_T_DictKey = TypeVar("_T_DictKey", UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network) T_T_DictKey = type[T_DictKey] @@ -92,6 +88,7 @@ _T_DictValues = TypeVar( IPv4Address, IPv4Network, "RestResourceBase", + NoneType, ) T_T_FieldValue = type(T_FieldValue) diff --git a/test/test_rest_login.py b/test/test_rest_login.py index 2fdf8f1..32d2b1f 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -33,6 +33,7 @@ from src.pyrestresource import ( ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, ) +from src.pyrestresource import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule from pprint import pprint testdir_path = Path(__file__).parent.resolve() @@ -76,16 +77,19 @@ def init_classes(): return resource class Login(RestResourceBase): - username: Optional[str] = Field(None, exclude=True) - # username: Optional[str] = Field(None) - secret: Optional[str] = Field(None, exclude=True) + username: Optional[str] = Field(None) + secret: Optional[str] = Field( + None, + exclude=True, + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.GET, rsrc_verb.DELETE, rsrc_verb.POST], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) @register_rest_rootpoint class RootApp(RestResourceBase): - login: Login = Field( - default=Login(), - plugin=ResourcePlugin_Login, - ) + login: Login = Field(default=Login(), plugin=ResourcePlugin_Login) # this add the classes to globals to allow using them later on # => this is only for uinit-testing purpose and is not needed in real use @@ -115,14 +119,37 @@ class Test_RestAPI_LOGIN(unittest.TestCase): self.testapp = RootApp() def test_login(self): + """ result = self.testapp.process_request("/login", rsrc_verb.GET) - print(result) + print("*****************") + print(result.get_result()) + + result = self.testapp.process_request("/login/username", rsrc_verb.GET) + print("*****************") + print(result.get_result()) + + # result = self.testapp.process_request("/login/secret", rsrc_verb.GET) + # print("*****************") + # print(result.get_result()) + """ result = self.testapp.process_request("/login", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') - print(result) + print("*****************") + print(result.get_result()) + """ result = self.testapp.process_request("/login", rsrc_verb.GET) - print(result) + print("*****************") + print(result.get_result()) + + result = self.testapp.process_request("/login/username", rsrc_verb.GET) + print("*****************") + print(result.get_result()) + + # result = self.testapp.process_request("/login/secret", rsrc_verb.GET) + # print("*****************") + # print(result.get_result()) + """ class Test_RestAPI_LOGIN_Web(unittest.TestCase): -- 2.47.3 From 346ff649ecb9f618f0e3eeb8d57d9ae5464cf24a Mon Sep 17 00:00:00 2001 From: cclecle Date: Fri, 3 Nov 2023 17:42:34 +0000 Subject: [PATCH 07/20] fix ACL + cleaning --- src/pyrestresource/rest_request.py | 7 +- src/pyrestresource/rest_resource.py | 39 +++-- src/pyrestresource/rest_resource_handler.py | 24 +-- src/pyrestresource/rest_types.py | 1 - test/test_ACL.py | 180 ++++++++++++++++++++ test/test_rest_login.py | 16 +- test/test_rest_resource.py | 147 ++++++++-------- test/test_rest_resource_plugins.py | 52 +++--- test/test_rest_resource_walker.py | 4 +- test/test_rest_resource_walker_tree.py | 5 +- test/test_rest_webserver.py | 2 - 11 files changed, 332 insertions(+), 145 deletions(-) create mode 100644 test/test_ACL.py diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index 80f223d..4549d0e 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -10,8 +10,9 @@ from re import sub from urllib.parse import urlparse, parse_qs from pydantic import BaseModel, Field +from typeguard import check_type -from .rest_types import rsrc_verb, T_SupportedRESTFields +from .rest_types import rsrc_verb, T_SupportedRESTFields, T_AllSupportedFields from .rest_request_opt import ( RestRequestParams_POST, @@ -129,6 +130,10 @@ class RestRequest(Generic[_T_RestRequestParams]): raise RuntimeError("url and verb and data must be set") self.url = url self.verb = verb + + if data != {} and not check_type(data, T_AllSupportedFields): + raise RuntimeError(f"Wrong data type received: {data}") + self.data = data # parse_qs returns list[] for all keys, the command convert list to single items so pydantic can eat them :) diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index 8a70a0e..cfad675 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -142,8 +142,8 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): if "ACL" in self.resource.json_schema_extra: if isinstance(self.resource.json_schema_extra["ACL"], list): - print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") - self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"] + # print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] else: raise RuntimeError("ACL must be a list()") @@ -174,6 +174,9 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi if self.parent.annotation._primary_key_ is not None: raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") self.parent.annotation._primary_key_ = self.resource_name + self.parent.annotation._ACL_record_[self.resource_name] = [ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY) + ] if "plugin" in self.resource.json_schema_extra: plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"] @@ -184,8 +187,8 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi if "ACL" in self.resource.json_schema_extra: if isinstance(self.resource.json_schema_extra["ACL"], list): - print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") - self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"] + # print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] else: raise RuntimeError("ACL must be a list()") @@ -223,8 +226,8 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ if "ACL" in self.resource.json_schema_extra: if isinstance(self.resource.json_schema_extra["ACL"], list): - print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") - self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"] + # print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] else: raise RuntimeError("ACL must be a list()") @@ -261,35 +264,35 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): ] ] = {} - def _check_acl(self, user: ACL_target_user, group: ACL_target_group, verb: rsrc_verb, field: str): - print(f"evaluate self ACLs rule: {self._ACL_record_}") - if verb is rsrc_verb.GET and self.model_fields[field].exclude is True: - print("ALLOWED (excluded field)") + def _check_acl(self, user: ACL_target_user, group: ACL_target_group, verb: rsrc_verb, field: str, is_self: bool = True): + # print(f"evaluate self ACLs rule: {self._ACL_record_}") + if is_self and verb is rsrc_verb.GET and self.model_fields[field].exclude is True: + # print("ALLOWED (excluded field)") return for acl in self._ACL_record_[field]: - print(f"evaluate ACL rule: {acl}") + # print(f"evaluate ACL rule: {acl}") if verb in acl.verbs: if isinstance(acl.target, ACL_target_user): if user == acl.target: if acl.rule is ACL_rule.ALLOW: - print("ALLOWED (user)") + # print("ALLOWED (user)") return raise RuntimeError(f"Not allowed access detected: {field}") elif isinstance(acl.target, ACL_target_group): if group == acl.target or acl.target == ACL_target_group_Any(): if acl.rule is ACL_rule.ALLOW: - print("ALLOWED (group)") + # print("ALLOWED (group)") return raise RuntimeError(f"Not allowed access detected: {field}") else: raise RuntimeError(f"Wrong ACL target type: {field}") - print("ALLOWED (Default)") + # print("ALLOWED (Default)") - def check_acl_access(self, request: RestRequest) -> None: + def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None: """Check ACL on requested field access""" - self._check_acl(request.user, request.group, request.get_verb(), request.get_resource_origin(0)) + self._check_acl(request.user, request.group, request.get_verb(), request.get_resource_origin(req_index), False) - def check_acl_operation(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None: + def check_acl_self(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None: """Check ACL on requested field operation (involving checking sub-fields)""" if request.get_verb() is rsrc_verb.GET: for key in self.model_fields.keys(): @@ -401,5 +404,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): request.set_result(json.dumps(result.model_dump(mode="json"))) elif result is not None: request.set_result(json.dumps(result, cls=_JSONEncoder)) + else: + request.set_result("null") return request diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 5996cc6..01ecc0b 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -183,7 +183,7 @@ class ResourceHandler( # reveal_type(next_resource) _next_resource = cast(_T_Resource, next_resource) # reveal_type(_next_resource) - print(f"[DEBUG] next_resource = {type(next_resource).__name__}") + # print(f"[DEBUG] next_resource = {type(next_resource).__name__}") if ( isinstance(_next_resource, RestResourceBase) @@ -464,13 +464,13 @@ class ResourceHandler_RestResourceBase( def _check_access_rights(self) -> None: super()._check_access_rights() - print(f"{type(self).__name__}->_check_access_rights()") + # print(f"{type(self).__name__}->_check_access_rights()") if self.req.get_resource_origin(0) == "/": return - print("==================") - print(self.req.get_resource_origin(0)) + # print("==================") + # print(self.req.get_resource_origin(0)) # print(len(self.req.get_url_stack())) # print(self.resource._model_dump_excluded_) # print(type(self.resource)) @@ -479,7 +479,7 @@ class ResourceHandler_RestResourceBase( if self.req.get_resource_origin(0) not in self.resource.model_fields: raise RuntimeError(f"Unknown field access detected: {self.req.get_url_stack()}") - self.resource.check_acl_access(self.req) + self.resource.check_acl_field(self.req) if len(self.req.get_url_stack()) == 0: # destination reached if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET: @@ -500,7 +500,7 @@ class ResourceHandler_RestResourceBase( # CASE 1: no more item in url_stack => we reached the endpoint (operation) # So we are in a RestResourceBase instance and must return the content if len(self.req.get_url_stack()) == 0: - self.resource.check_acl_operation(self.req) + self.resource.check_acl_self(self.req, None) for key, attr in self.resource.model_fields.items(): if key in self.resource._plugins_: if isinstance(self.resource._plugins_[key], ResourcePlugin_field): @@ -522,6 +522,7 @@ class ResourceHandler_RestResourceBase( return self.resource # CASE 3: in between (access) + self.resource.check_acl_field(self.req) value = getattr(self.resource, self.req.get_resource_origin(0)) key = self.req.get_resource_origin(0) @@ -546,7 +547,7 @@ class ResourceHandler_RestResourceBase( # print(f"{type(self).__name__}->_process_put()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") - self.resource.check_acl_operation(self.req, self.req.get_data()) + self.resource.check_acl_self(self.req, self.req.get_data()) # creating a copy of the current resource _new_resrc = self.resource.copy() @@ -564,9 +565,8 @@ class ResourceHandler_RestResourceBase( # applying plugins (from parent element) if self.prev_handler is not None: - # element is within a dict if ( - isinstance(self.prev_handler.resource, dict) + isinstance(self.prev_handler.resource, dict) # element is within a dict and self.prev_handler.prev_handler is not None and isinstance(self.prev_handler.prev_handler.resource, RestResourceBase) ): @@ -587,7 +587,7 @@ class ResourceHandler_RestResourceBase( ) _new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params) - self.resource.update(**_new_resrc.dict()) + self.resource.update(**_new_resrc.__dict__) return def _handle_process_delete(self, params) -> None: @@ -629,6 +629,8 @@ class ResourceHandler_simple( assert self.prev_handler is not None assert isinstance(self.prev_handler.resource, RestResourceBase) + self.prev_handler.resource.check_acl_field(self.req, 1) + if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, @@ -645,6 +647,8 @@ class ResourceHandler_simple( assert self.prev_handler is not None assert isinstance(self.prev_handler.resource, RestResourceBase) + self.prev_handler.resource.check_acl_field(self.req, 1) + value = self.req.get_data() if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: diff --git a/src/pyrestresource/rest_types.py b/src/pyrestresource/rest_types.py index 55e7c12..98edc7c 100644 --- a/src/pyrestresource/rest_types.py +++ b/src/pyrestresource/rest_types.py @@ -98,5 +98,4 @@ T_Dict = dict[T_DictKey, T_DictValues] _T_Dict = dict[_T_DictKey, _T_DictValues] T_AllSupportedFields = T_Dict | T_FieldValue -T_AllSupportedFiels = T_Dict | T_FieldValue T_AllSupportedContainers = Union[T_Dict, "RestResourceBase"] diff --git a/test/test_ACL.py b/test/test_ACL.py new file mode 100644 index 0000000..17859e3 --- /dev/null +++ b/test/test_ACL.py @@ -0,0 +1,180 @@ +from __future__ import annotations +import unittest +from os import chdir +from pathlib import Path +from typing import Optional +from pydantic import Field + + +print(__name__) +print(__package__) + + +from src.pyrestresource import ( + register_rest_rootpoint, + RestResourceBase, + rsrc_verb, + RestRequestParams_GET, + RestRequestParams_POST, + RestRequestParams_Dict_GET, + RestRequestParams_PUT, + T_SupportedRESTFields, + ResourcePlugin_field_default, + ResourcePlugin_RestResourceBase_default, + ACL_target_group_Any, + ACL_record, + ACL_rule, +) + + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +# to allow mock-ing, all the tested classes are in a function +def init_classes(): + class TestResource(RestResourceBase): + username: Optional[str] = Field(None) + secret: Optional[str] = Field( + None, + exclude=True, + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + + class TestResource2(RestResourceBase): + version_ro: Optional[str] = Field( + "1.2.3", + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + version: Optional[str] = Field("3.2.1") + + @register_rest_rootpoint + class RootApp(RestResourceBase): + resource_with_secret: TestResource = Field(default=TestResource()) + resource_with_secret_ACL: TestResource = Field( + default=TestResource(), ACL=[ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY)] + ) + resource2: TestResource2 = Field(TestResource2()) + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[TestResource.__name__] = TestResource + globals()[RootApp.__name__] = RootApp + + +class Test_RestAPI_ACL(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_subresource_readonly(self): + result = self.testapp.process_request("/", rsrc_verb.GET) + self.assertEqual(result.get_result(), "{}") + + result = self.testapp.process_request("/resource2", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "3.2.1"}') + + self.testapp.process_request("/resource2/version", rsrc_verb.PUT, '"6.6.6"') + + result = self.testapp.process_request("/resource2", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') + + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/resource2/version_ro", rsrc_verb.PUT, '"6.6.6"') + + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/resource2", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}') + + result = self.testapp.process_request("/resource2", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') + + def test_subresource(self): + result = self.testapp.process_request("/", rsrc_verb.GET) + self.assertEqual(result.get_result(), "{}") + + result = self.testapp.process_request("/resource_with_secret", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"username": null}') + + result = self.testapp.process_request("/resource_with_secret/username", rsrc_verb.GET) + self.assertEqual(result.get_result(), "null") + self.assertEqual(self.testapp.resource_with_secret.username, None) + + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) + + self.assertEqual(self.testapp.resource_with_secret.secret, None) + + result = self.testapp.process_request("/resource_with_secret", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') + self.assertEqual(result.get_result(), "null") + + result = self.testapp.process_request("/resource_with_secret", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"username": "chacha"}') + + result = self.testapp.process_request("/resource_with_secret/username", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"chacha"') + self.assertEqual(self.testapp.resource_with_secret.username, "chacha") + + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) + + self.assertEqual(self.testapp.resource_with_secret.secret, "123456") + + def test_subresource_field(self): + result = self.testapp.process_request("/resource_with_secret/username", rsrc_verb.PUT, '"chacha"') + self.assertEqual(result.get_result(), "null") + + result = self.testapp.process_request("/resource_with_secret", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"username": "chacha"}') + + result = self.testapp.process_request("/resource_with_secret/username", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"chacha"') + self.assertEqual(self.testapp.resource_with_secret.username, "chacha") + + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) + + result = self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.PUT, '"123456"') + self.assertEqual(result.get_result(), "null") + + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) + + self.assertEqual(self.testapp.resource_with_secret.secret, "123456") + + def test_subresource_ACL(self): + result = self.testapp.process_request("/", rsrc_verb.GET) + self.assertEqual(result.get_result(), "{}") + + result = self.testapp.process_request("/resource_with_secret_ACL", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"username": null}') + + result = self.testapp.process_request("/resource_with_secret_ACL/username", rsrc_verb.GET) + self.assertEqual(result.get_result(), "null") + self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) + + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/resource_with_secret_ACL/secret", rsrc_verb.GET) + + self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) + + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/resource_with_secret_ACL", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') + self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) + self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) + + def test_subresource_ACL_field(self): + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/resource_with_secret_ACL/username", rsrc_verb.PUT, '"chacha"') + self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) + self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) + + with self.assertRaises(RuntimeError): # TODO: custom exception + self.testapp.process_request("/resource_with_secret_ACL/secret", rsrc_verb.PUT, '"123456"') + self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) + self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) diff --git a/test/test_rest_login.py b/test/test_rest_login.py index 32d2b1f..ff71667 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -7,8 +7,6 @@ from typing import Optional, Annotated from pydantic import Field from uuid import UUID, uuid4 from time import time, sleep -from time import time -import json import uvicorn import socket import requests @@ -32,9 +30,11 @@ from src.pyrestresource import ( T_SupportedRESTFields, ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, + ACL_target_group_Any, + ACL_record, + ACL_rule, ) -from src.pyrestresource import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule -from pprint import pprint + testdir_path = Path(__file__).parent.resolve() chdir(testdir_path.parent.resolve()) @@ -83,7 +83,7 @@ def init_classes(): exclude=True, ACL=[ ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), - ACL_record(verbs=[rsrc_verb.GET, rsrc_verb.DELETE, rsrc_verb.POST], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ], ) @@ -101,8 +101,6 @@ def find_free_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - hostname = socket.gethostname() - IPAddr = socket.gethostbyname(hostname) return "localhost", s.getsockname()[1] @@ -119,7 +117,6 @@ class Test_RestAPI_LOGIN(unittest.TestCase): self.testapp = RootApp() def test_login(self): - """ result = self.testapp.process_request("/login", rsrc_verb.GET) print("*****************") print(result.get_result()) @@ -131,13 +128,11 @@ class Test_RestAPI_LOGIN(unittest.TestCase): # result = self.testapp.process_request("/login/secret", rsrc_verb.GET) # print("*****************") # print(result.get_result()) - """ result = self.testapp.process_request("/login", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') print("*****************") print(result.get_result()) - """ result = self.testapp.process_request("/login", rsrc_verb.GET) print("*****************") print(result.get_result()) @@ -149,7 +144,6 @@ class Test_RestAPI_LOGIN(unittest.TestCase): # result = self.testapp.process_request("/login/secret", rsrc_verb.GET) # print("*****************") # print(result.get_result()) - """ class Test_RestAPI_LOGIN_Web(unittest.TestCase): diff --git a/test/test_rest_resource.py b/test/test_rest_resource.py index 5120150..f3b7c9d 100644 --- a/test/test_rest_resource.py +++ b/test/test_rest_resource.py @@ -21,6 +21,9 @@ from src.pyrestresource import ( RestRequestParams_POST, RestRequestParams_Dict_GET, T_SupportedRESTFields, + ACL_target_group_Any, + ACL_record, + ACL_rule, ) from pprint import pprint @@ -58,9 +61,19 @@ def init_classes(): Patch_2 = Patch(uuid="d385a1d2-65fa-11ee-8c99-0242ac120002", shortname="testPatch2") class User(RestResourceBase): - uuid: UUID = Field(default_factory=uuid4, primary_key=True) + uuid: UUID = Field( + default_factory=uuid4, + primary_key=True, + ) name: str - secret: str = Field(..., exclude=True) + secret: str = Field( + ..., + exclude=True, + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) User1 = User( uuid="8da57a3c-661f-11ee-8c99-0242ac120002", @@ -68,8 +81,6 @@ def init_classes(): secret="la blanquette est bonne", ) - ext_patchs: dict[UUID, Patch] = {} - class Patch2(RestResourceBase): uuid: UUID = Field(default_factory=uuid4, primary_key=True) shortname: str @@ -117,100 +128,100 @@ class Test_RestAPI_GET(unittest.TestCase): def test_get_root(self): result = self.testapp.process_request("/", rsrc_verb.GET) - self.assertEqual(result, '{"testValueRoot": 3.14}') + self.assertEqual(result.get_result(), '{"testValueRoot": 3.14}') def test_get_root__multiple_slash(self): result = self.testapp.process_request("/////", rsrc_verb.GET) - self.assertEqual(result, '{"testValueRoot": 3.14}') + self.assertEqual(result.get_result(), '{"testValueRoot": 3.14}') result = self.testapp.process_request("////", rsrc_verb.GET) - self.assertEqual(result, '{"testValueRoot": 3.14}') + self.assertEqual(result.get_result(), '{"testValueRoot": 3.14}') def test_get_root__nested_value(self): result = self.testapp.process_request("/testValueRoot", rsrc_verb.GET) - self.assertEqual(result, "3.14") + self.assertEqual(result.get_result(), "3.14") def test_get_root__nested_value__trailing_slash(self): result = self.testapp.process_request("/testValueRoot/", rsrc_verb.GET) - self.assertEqual(result, "3.14") + self.assertEqual(result.get_result(), "3.14") result = self.testapp.process_request("/testValueRoot//", rsrc_verb.GET) - self.assertEqual(result, "3.14") + self.assertEqual(result.get_result(), "3.14") result = self.testapp.process_request("/testValueRoot///", rsrc_verb.GET) - self.assertEqual(result, "3.14") + self.assertEqual(result.get_result(), "3.14") def test_get_root__nested_value__multiple_slash(self): result = self.testapp.process_request("//testValueRoot", rsrc_verb.GET) - self.assertEqual(result, "3.14") + self.assertEqual(result.get_result(), "3.14") result = self.testapp.process_request("///testValueRoot", rsrc_verb.GET) - self.assertEqual(result, "3.14") + self.assertEqual(result.get_result(), "3.14") def test_get_version(self): result = self.testapp.process_request("/info", rsrc_verb.GET) - self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') def test_get_version__trailing_slash(self): result = self.testapp.process_request("/info/", rsrc_verb.GET) - self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') result = self.testapp.process_request("/info//", rsrc_verb.GET) - self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') result = self.testapp.process_request("/info///", rsrc_verb.GET) - self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') def test_get_version__multiple_slash(self): result = self.testapp.process_request("//info", rsrc_verb.GET) - self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') result = self.testapp.process_request("///info", rsrc_verb.GET) - self.assertEqual(result, '{"version": "0.0.1", "api_version": "0.0.2"}') + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') def test_get_version__nested_value(self): result = self.testapp.process_request("/info/api_version", rsrc_verb.GET) - self.assertEqual(result, '"0.0.2"') + self.assertEqual(result.get_result(), '"0.0.2"') result = self.testapp.process_request("/info/version", rsrc_verb.GET) - self.assertEqual(result, '"0.0.1"') + self.assertEqual(result.get_result(), '"0.0.1"') def test_get_dict_games(self): result = self.testapp.process_request("/games", rsrc_verb.GET) - self.assertEqual(result, '["9b0381d4-65f6-11ee-8c99-0242ac120002"]') + self.assertEqual(result.get_result(), '["9b0381d4-65f6-11ee-8c99-0242ac120002"]') def test_get_dict_patchs(self): result = self.testapp.process_request("/patchs", rsrc_verb.GET) self.assertEqual( - result, + result.get_result(), '["cee1e870-65fa-11ee-8c99-0242ac120002", "d385a1d2-65fa-11ee-8c99-0242ac120002"]', ) def test_get_dict_patch_element(self): result = self.testapp.process_request("/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", rsrc_verb.GET) self.assertEqual( - result, + result.get_result(), '{"uuid": "cee1e870-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch1", "name": null, "description": null}', ) def test_get_dict_game_element(self): result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002", rsrc_verb.GET) expected = '{"uuid": "9b0381d4-65f6-11ee-8c99-0242ac120002", "shortname": "testGame", "name": null, "description": null}' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_get_dict_game_element__nested_value(self): result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname", rsrc_verb.GET) expected = '"testGame"' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_get_dict_game_element__nested_value2(self): result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/uuid", rsrc_verb.GET) expected = '"9b0381d4-65f6-11ee-8c99-0242ac120002"' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_get_nested_dict_games_patchs(self): result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) - self.assertEqual(result, '["cee1e870-65fa-11ee-8c99-0242ac120002"]') + self.assertEqual(result.get_result(), '["cee1e870-65fa-11ee-8c99-0242ac120002"]') def test_get_nested_dict_games_patch_element(self): result = self.testapp.process_request( @@ -218,28 +229,28 @@ class Test_RestAPI_GET(unittest.TestCase): rsrc_verb.GET, ) expected = '{"uuid": "cee1e870-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch1", "name": null, "description": null}' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_get_nested_dict_games_patch_element__nested_value(self): result = self.testapp.process_request( "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/uuid", rsrc_verb.GET, ) - self.assertEqual(result, '"cee1e870-65fa-11ee-8c99-0242ac120002"') + self.assertEqual(result.get_result(), '"cee1e870-65fa-11ee-8c99-0242ac120002"') def test_get_dict_game_element__API_nested(self): result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002?API_nested=True", rsrc_verb.GET) expected = '{"uuid": "9b0381d4-65f6-11ee-8c99-0242ac120002", "shortname": "testGame", "name": null, "description": null}' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_get_dict_users(self): result = self.testapp.process_request("/users", rsrc_verb.GET) - self.assertEqual(result, '["8da57a3c-661f-11ee-8c99-0242ac120002"]') + self.assertEqual(result.get_result(), '["8da57a3c-661f-11ee-8c99-0242ac120002"]') def test_get_dict_user_element(self): result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.GET) self.assertEqual( - result, + result.get_result(), '{"uuid": "8da57a3c-661f-11ee-8c99-0242ac120002", "name": "chacha"}', "no secret seen", ) @@ -247,14 +258,14 @@ class Test_RestAPI_GET(unittest.TestCase): def test_get_dict_user_element2(self): result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002?API_nested=True", rsrc_verb.GET) self.assertEqual( - result, + result.get_result(), '{"uuid": "8da57a3c-661f-11ee-8c99-0242ac120002", "name": "chacha"}', "no secret seen", ) def test_get_dict_user_element__nested_value(self): result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002/name", rsrc_verb.GET) - self.assertEqual(result, '"chacha"') + self.assertEqual(result.get_result(), '"chacha"') def test_get_dict_user_element__nested_value__forbiden(self): with self.assertRaises(RuntimeError): # TODO: custom exception @@ -278,7 +289,7 @@ class Test_RestAPI_PUT(unittest.TestCase): self.testapp.process_request("/info", rsrc_verb.PUT, '{"version": "1.2.3", "api_version": "3.2.1"}') result = self.testapp.process_request("/info", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.2.3", "api_version": "3.2.1"}') + self.assertEqual(result.get_result(), '{"version": "1.2.3", "api_version": "3.2.1"}') def test_put_dict_user_nested_value(self): self.testapp.process_request( @@ -288,12 +299,12 @@ class Test_RestAPI_PUT(unittest.TestCase): ) result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002/name", rsrc_verb.GET) - self.assertEqual(result, '"chacha2"') + self.assertEqual(result.get_result(), '"chacha2"') def test_put_user_nested_value__forbiden(self): with self.assertRaises(RuntimeError): # TODO: custom exception self.testapp.process_request( - "/users/8da57a3c-661f-11ee-8c99-0242ac120002/secret", + "/users/8da57a3c-661f-11ee-8c99-0242ac120002/uuid", rsrc_verb.PUT, '"test"', ) @@ -307,11 +318,11 @@ class Test_RestAPI_PUT(unittest.TestCase): result = self.testapp.process_request("/users", rsrc_verb.GET) expected = '["8da57a3c-661f-11ee-8c99-0242ac120002"]' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.GET) expected = '{"uuid": "8da57a3c-661f-11ee-8c99-0242ac120002", "name": "testUser4"}' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_put_dict_patch__nested(self): self.testapp.process_request( @@ -325,7 +336,7 @@ class Test_RestAPI_PUT(unittest.TestCase): rsrc_verb.GET, ) expected = '{"uuid": "cee1e870-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch998", "name": "MyPatch", "description": "MyDescription123"}' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) class Test_RestAPI_POST(unittest.TestCase): @@ -340,15 +351,15 @@ class Test_RestAPI_POST(unittest.TestCase): rsrc_verb.POST, '{"name": "testUser", "secret": "test"}', ) - self.assertEqual(result, '"e5e87d32-662b-11ee-8c99-0242ac120002"') + self.assertEqual(result.get_result(), '"e5e87d32-662b-11ee-8c99-0242ac120002"') result = self.testapp.process_request("/users", rsrc_verb.GET) expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "e5e87d32-662b-11ee-8c99-0242ac120002"]' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) result = self.testapp.process_request("/users/e5e87d32-662b-11ee-8c99-0242ac120002", rsrc_verb.GET) expected = '{"uuid": "e5e87d32-662b-11ee-8c99-0242ac120002", "name": "testUser"}' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_post_dict_user__nested_key(self): result = self.testapp.process_request( @@ -356,15 +367,15 @@ class Test_RestAPI_POST(unittest.TestCase): rsrc_verb.POST, '{"name": "testUser2", "secret": "test", "uuid":"e7e86d32-662b-11ee-8c99-0242ac120002"}', ) - self.assertEqual(result, '"e7e86d32-662b-11ee-8c99-0242ac120002"') + self.assertEqual(result.get_result(), '"e7e86d32-662b-11ee-8c99-0242ac120002"') result = self.testapp.process_request("/users", rsrc_verb.GET) expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "e7e86d32-662b-11ee-8c99-0242ac120002"]' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) result = self.testapp.process_request("/users/e7e86d32-662b-11ee-8c99-0242ac120002", rsrc_verb.GET) expected = '{"uuid": "e7e86d32-662b-11ee-8c99-0242ac120002", "name": "testUser2"}' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) @patch(f"{__loader__.name }.uuid4") def test_post_dict_user__auto_key(self, mock_uuid4): @@ -375,15 +386,15 @@ class Test_RestAPI_POST(unittest.TestCase): self.testapp = RootApp() result = self.testapp.process_request("/users", rsrc_verb.POST, '{"name": "testUser3", "secret": "test"}') - self.assertEqual(result, '"5faccb2e-69aa-11ee-8c99-0242ac120002"') + self.assertEqual(result.get_result(), '"5faccb2e-69aa-11ee-8c99-0242ac120002"') result = self.testapp.process_request("/users", rsrc_verb.GET) expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "5faccb2e-69aa-11ee-8c99-0242ac120002"]' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) result = self.testapp.process_request("/users/5faccb2e-69aa-11ee-8c99-0242ac120002", rsrc_verb.GET) expected = '{"uuid": "5faccb2e-69aa-11ee-8c99-0242ac120002", "name": "testUser3"}' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_post_dict_patch__nested_API_key(self): self.testapp.process_request( @@ -397,7 +408,7 @@ class Test_RestAPI_POST(unittest.TestCase): rsrc_verb.GET, ) expected = '{"uuid": "cee1e971-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch99", "name": "MyPatch", "description": "MyDescription"}' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) class Test_RestAPI_DELETE(unittest.TestCase): @@ -411,7 +422,7 @@ class Test_RestAPI_DELETE(unittest.TestCase): result = self.testapp.process_request("/users", rsrc_verb.GET) expected = "[]" - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_delete_dict_user__All(self): result = self.testapp.process_request( @@ -419,24 +430,24 @@ class Test_RestAPI_DELETE(unittest.TestCase): rsrc_verb.POST, '{"name": "testUser", "secret": "test"}', ) - self.assertEqual(result, '"e5e87d32-662b-11ee-8c99-0242ac120002"') + self.assertEqual(result.get_result(), '"e5e87d32-662b-11ee-8c99-0242ac120002"') result = self.testapp.process_request("/users", rsrc_verb.GET) expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "e5e87d32-662b-11ee-8c99-0242ac120002"]' - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) self.testapp.process_request("/users", rsrc_verb.DELETE) result = self.testapp.process_request("/users", rsrc_verb.GET) expected = "[]" - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_delete_dict_user_element(self): self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.DELETE) result = self.testapp.process_request("/users", rsrc_verb.GET) expected = "[]" - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_delete_nested_dict_games_patch_element(self): self.testapp.process_request( @@ -446,7 +457,7 @@ class Test_RestAPI_DELETE(unittest.TestCase): result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) expected = "[]" - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_delete_nested_dict_games_patch_API_key(self): self.testapp.process_request( @@ -456,14 +467,14 @@ class Test_RestAPI_DELETE(unittest.TestCase): result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) expected = "[]" - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) def test_delete_nested_dict_games_patch_All(self): self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.DELETE) result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) expected = "[]" - self.assertEqual(result, expected) + self.assertEqual(result.get_result(), expected) class Test_RestAPI_PERFO(unittest.TestCase): @@ -478,13 +489,13 @@ class Test_RestAPI_PERFO(unittest.TestCase): n_loop = 10000 start = time() - for i in range(n_loop): + for _ in range(n_loop): self.testapp.process_request(f"/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.GET) end = time() print(f"GET 1st level dict: {int(n_loop/(end-start))} Req/s") start = time() - for i in range(n_loop): + for _ in range(n_loop): newUUID = uuid4() self.testapp.process_request( f"/users?API_key={newUUID}", @@ -495,7 +506,7 @@ class Test_RestAPI_PERFO(unittest.TestCase): print(f"POST 1st level dict (API_key): {int(n_loop/(end-start))} Req/s") start = time() - for i in range(n_loop): + for _ in range(n_loop): newUUID = uuid4() self.testapp.process_request( f"/users?API_key={newUUID}", @@ -507,14 +518,14 @@ class Test_RestAPI_PERFO(unittest.TestCase): print(f"POST/GET 1st level dict (API_key): {int(n_loop/(end-start))} Req/s") start = time() - for i in range(n_loop): + for _ in range(n_loop): result = self.testapp.process_request(f"/users", rsrc_verb.POST, '{"name": "testUser", "secret": "test"}') - self.testapp.process_request(f"/users/{json.loads(result)}", rsrc_verb.GET) + self.testapp.process_request(f"/users/{json.loads(result.get_result())}", rsrc_verb.GET) end = time() print(f"POST/GET 1st level dict (autokey): {int(n_loop/(end-start))} Req/s") start = time() - for i in range(n_loop): + for _ in range(n_loop): self.testapp.process_request( f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname", rsrc_verb.PUT, @@ -525,7 +536,7 @@ class Test_RestAPI_PERFO(unittest.TestCase): print(f"PUT/GET 1st level (value) dict: {int(n_loop/(end-start))} Req/s") start = time() - for i in range(n_loop): + for _ in range(n_loop): self.testapp.process_request( f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", rsrc_verb.GET, @@ -534,7 +545,7 @@ class Test_RestAPI_PERFO(unittest.TestCase): print(f"GET 2nd level dict: {int(n_loop/(end-start))} Req/s") start = time() - for i in range(n_loop): + for _ in range(n_loop): self.testapp.process_request( f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", rsrc_verb.GET, @@ -543,7 +554,7 @@ class Test_RestAPI_PERFO(unittest.TestCase): print(f"GET 2nd level (value) dict: {int(n_loop/(end-start))} Req/s") start = time() - for i in range(n_loop): + for _ in range(n_loop): self.testapp.process_request( f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", rsrc_verb.PUT, diff --git a/test/test_rest_resource_plugins.py b/test/test_rest_resource_plugins.py index ac40326..cffa355 100644 --- a/test/test_rest_resource_plugins.py +++ b/test/test_rest_resource_plugins.py @@ -1,16 +1,9 @@ from __future__ import annotations import unittest -from unittest.mock import patch from os import chdir from pathlib import Path -from typing import Optional, Annotated +from typing import Annotated from pydantic import Field -from uuid import UUID, uuid4 -from time import time -import json - -print(__name__) -print(__package__) from src.pyrestresource import ( register_rest_rootpoint, @@ -24,7 +17,6 @@ from src.pyrestresource import ( ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, ) -from pprint import pprint testdir_path = Path(__file__).parent.resolve() chdir(testdir_path.parent.resolve()) @@ -133,16 +125,16 @@ class Test_RestAPI_Plugin_PUT(unittest.TestCase): self.testapp.process_request("/info_put/version", rsrc_verb.PUT, '"1.5.6"') result = self.testapp.process_request("/info_put", rsrc_verb.GET) - print(result) + print(result.get_result()) result = self.testapp.process_request("/info_put/version", rsrc_verb.GET) - print(result) - self.assertEqual(result, '"42"') + print(result.get_result()) + self.assertEqual(result.get_result(), '"42"') def test_put_field_version_resourceplugin(self): self.testapp.process_request("/info_put", rsrc_verb.PUT, '{"version": "1.5.6", "api_version": "98.321"}') result = self.testapp.process_request("/info_put", rsrc_verb.GET) - self.assertEqual(result, '{"version": "42", "api_version": "98.321"}') + self.assertEqual(result.get_result(), '{"version": "42", "api_version": "98.321"}') class Test_RestAPI_Plugin_GET(unittest.TestCase): @@ -153,59 +145,59 @@ class Test_RestAPI_Plugin_GET(unittest.TestCase): def test_get_root(self): result = self.testapp.process_request("/", rsrc_verb.GET) - self.assertEqual(result, "{}") + self.assertEqual(result.get_result(), "{}") def test_get_version(self): result = self.testapp.process_request("/info", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') result = self.testapp.process_request("/info2", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') def test_get_version__trailing_slash(self): result = self.testapp.process_request("/info/", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') result = self.testapp.process_request("/info//", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') result = self.testapp.process_request("/info///", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') result = self.testapp.process_request("/info2/", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') result = self.testapp.process_request("/info2//", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') result = self.testapp.process_request("/info2///", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') def test_get_version__multiple_slash(self): result = self.testapp.process_request("//info", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') result = self.testapp.process_request("///info", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "98.321"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') result = self.testapp.process_request("//info2", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') result = self.testapp.process_request("///info2", rsrc_verb.GET) - self.assertEqual(result, '{"version": "1.5.6", "api_version": "0.0.3"}') + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') def test_get_version__nested_value(self): result = self.testapp.process_request("/info/api_version", rsrc_verb.GET) - self.assertEqual(result, '"98.321"') + self.assertEqual(result.get_result(), '"98.321"') result = self.testapp.process_request("/info/version", rsrc_verb.GET) - self.assertEqual(result, '"1.5.6"') + self.assertEqual(result.get_result(), '"1.5.6"') result = self.testapp.process_request("/info2/api_version", rsrc_verb.GET) - self.assertEqual(result, '"0.0.3"') + self.assertEqual(result.get_result(), '"0.0.3"') result = self.testapp.process_request("/info2/version", rsrc_verb.GET) - self.assertEqual(result, '"1.5.6"') + self.assertEqual(result.get_result(), '"1.5.6"') def test_defect_plugin_field(self): with self.assertRaises(RuntimeError): diff --git a/test/test_rest_resource_walker.py b/test/test_rest_resource_walker.py index d562dc5..ecaefd2 100644 --- a/test/test_rest_resource_walker.py +++ b/test/test_rest_resource_walker.py @@ -1,13 +1,13 @@ from __future__ import annotations import unittest -from typing import Optional, cast +from typing import Optional from os import chdir from pathlib import Path from pydantic import Field from io import StringIO -from contextlib import redirect_stdout, redirect_stderr +from contextlib import redirect_stdout print(__name__) print(__package__) diff --git a/test/test_rest_resource_walker_tree.py b/test/test_rest_resource_walker_tree.py index 8ae2a48..afaeb40 100644 --- a/test/test_rest_resource_walker_tree.py +++ b/test/test_rest_resource_walker_tree.py @@ -1,13 +1,12 @@ from __future__ import annotations import unittest -from typing import Annotated, Optional +from typing import Optional from os import chdir from pathlib import Path from pydantic import Field -from io import StringIO -from contextlib import redirect_stdout, redirect_stderr + print(__name__) print(__package__) diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py index 76033a9..6eebc21 100644 --- a/test/test_rest_webserver.py +++ b/test/test_rest_webserver.py @@ -117,8 +117,6 @@ def find_free_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - hostname = socket.gethostname() - IPAddr = socket.gethostbyname(hostname) return "localhost", s.getsockname()[1] -- 2.47.3 From f00cf7b4b2d60fc3eb55240eef931fca77b46277 Mon Sep 17 00:00:00 2001 From: cclecle Date: Sun, 5 Nov 2023 15:38:08 +0000 Subject: [PATCH 08/20] continuing implementation of login and session --- .settings/org.eclipse.core.resources.prefs | 1 + src/pyrestresource/__init__.py | 1 + src/pyrestresource/helpers.py | 14 +++ src/pyrestresource/rest_ACL.py | 4 - src/pyrestresource/rest_login.py | 127 ++++++++++++++++++++ src/pyrestresource/rest_request.py | 32 ++++- src/pyrestresource/rest_resource.py | 54 ++++++--- src/pyrestresource/rest_resource_handler.py | 34 +++--- src/pyrestresource/rest_resource_plugin.py | 37 ++++-- test/test_ACL.py | 16 +-- test/test_rest_login.py | 118 ++++++++++-------- 11 files changed, 325 insertions(+), 113 deletions(-) create mode 100644 src/pyrestresource/rest_login.py diff --git a/.settings/org.eclipse.core.resources.prefs b/.settings/org.eclipse.core.resources.prefs index e658763..c89f8d0 100644 --- a/.settings/org.eclipse.core.resources.prefs +++ b/.settings/org.eclipse.core.resources.prefs @@ -1,6 +1,7 @@ eclipse.preferences.version=1 encoding//src/pyrestresource/__init__.py=utf-8 encoding//src/pyrestresource/__metadata__.py=utf-8 +encoding//src/pyrestresource/rest_login.py=utf-8 encoding//src/pyrestresource/rest_resource.py=utf-8 encoding//src/pyrestresource/rest_resource_handler_walker.py=utf-8 encoding/=UTF-8 diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index 15d39df..3722afd 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -54,3 +54,4 @@ from .rest_resource_plugin import ( ResourcePlugin_dict_default, ) from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule +from .rest_login import RestResourceBaseLogin, UserLogin diff --git a/src/pyrestresource/helpers.py b/src/pyrestresource/helpers.py index a39262c..a08067b 100644 --- a/src/pyrestresource/helpers.py +++ b/src/pyrestresource/helpers.py @@ -15,3 +15,17 @@ class _JSONEncoder(json.JSONEncoder): # if the obj is uuid, we simply return the value of uuid return str(o) return json.JSONEncoder.default(self, o) + + +def parse_dict_cookies(cookies: str) -> dict[str, str]: + result = {} + for item in cookies.split(";"): + item = item.strip() + if not item: + continue + if "=" not in item: + result[item] = None + continue + name, value = item.split("=", 1) + result[name] = value + return result diff --git a/src/pyrestresource/rest_ACL.py b/src/pyrestresource/rest_ACL.py index a75fa8b..3ebdbaa 100644 --- a/src/pyrestresource/rest_ACL.py +++ b/src/pyrestresource/rest_ACL.py @@ -22,10 +22,6 @@ class ACL_target_group(ACL_target): name: str -class ACL_target_group_Annonymous(ACL_target): - name: str = "__ANNONYMOUS__" - - class ACL_target_group_Any(ACL_target_group): name: str = "__ANY__" diff --git a/src/pyrestresource/rest_login.py b/src/pyrestresource/rest_login.py new file mode 100644 index 0000000..0fa0c2e --- /dev/null +++ b/src/pyrestresource/rest_login.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# pyrestresource(c) by chacha +# +# pyrestresource is licensed under a +# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. +# +# You should have received a copy of the license along with this +# work. If not, see . + +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring + +"""CLI interface module""" +from __future__ import annotations + +from typing import Optional, ClassVar, TYPE_CHECKING +from secrets import token_hex, compare_digest +from datetime import datetime + +from pydantic import BaseModel, Field + +from .rest_types import rsrc_verb +from .rest_resource import RestResourceBase + +from .rest_request import RestRequest, RestRequestParams_GET +from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule + +if TYPE_CHECKING or True: + from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default + + +class UserLogin(BaseModel): + username: str + secret: str + + +class UserSession(BaseModel): + last_update: datetime + user_login: UserLogin + host: Optional[str] + + +class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): + ar_UserLogin: list[UserLogin] = [] + + def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login: + print("hook GET") + print(resource) + print(params) + return resource + + def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login: + print("hook PUT") + # print(self.get_ar_userlogin()) + print(resource.username) + print(resource.secret) + + token = self.user_login(resource.username, resource.secret) + self.set_resp_cookie_value("Authorization", f"Bearer {token}") + + return resource + + +class Login(RestResourceBase): + username: Optional[str] = Field(None) + secret: Optional[str] = Field( + None, + exclude=True, + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + + +class RestResourceBaseLogin(RestResourceBase): + _ar_user_login: ClassVar[list[UserLogin]] = [] + _ar_user_session: dict[str, UserSession] = {} + _max_session_time_minutes: ClassVar[int] = 20 + login: Login = Field(default=Login(), plugin=ResourcePlugin_Login) + + def _process_request_session(self, request: RestRequest) -> None: + auth_cookie = request.get_cookie("Authorization") + if auth_cookie != None: + if auth_cookie in self._ar_user_session: + print("USER SESSION FOUND !") + print(self._ar_user_session[auth_cookie].user_login.username) + print(auth_cookie) + + time_diff_min = (datetime.now() - self._ar_user_session[auth_cookie].last_update).total_seconds() / 60 + + if time_diff_min > self._max_session_time_minutes: + del self._ar_user_session[auth_cookie] + raise RuntimeError("session timeout ! (session reseted)") + + request.set_user(self._ar_user_session[auth_cookie].user_login.username) + return + + print("Invalid session") + return + + print("non-connected user") + + def user_login(self, user_name: str, user_secret: str, request: RestRequest) -> str: + already_failed: bool = False + + for iter_user_login in self._ar_user_login: + username_ok: bool = compare_digest(user_name, iter_user_login.username) + secret_ok: bool = compare_digest(user_secret, iter_user_login.secret) + + if username_ok is True: + if secret_ok is True and not already_failed: + return self._register_user_session(iter_user_login, request) + else: + already_failed = True + else: + pass + pass + + if already_failed: + raise RuntimeError("Wrong auth") # TODO: specific exception + + def _register_user_session(self, user_login: UserLogin, request: RestRequest) -> str: + token = token_hex(16) + new_user_session = UserSession(last_update=datetime.now(), user_login=user_login, host=request.get_host()) + self._ar_user_session[f"Bearer {token}"] = new_user_session + return token diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index 4549d0e..c5d08cd 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -3,11 +3,14 @@ from __future__ import annotations from typing import ( + Any, Optional, Generic, ) from re import sub from urllib.parse import urlparse, parse_qs +from http.cookies import SimpleCookie + from pydantic import BaseModel, Field from typeguard import check_type @@ -26,7 +29,8 @@ from .rest_request_opt import ( _T_RestRequestParams_PUT, ) -from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group, ACL_target_group_Annonymous +from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group +from .helpers import parse_dict_cookies class RequestFactory( @@ -114,6 +118,8 @@ class RestRequest(Generic[_T_RestRequestParams]): self.url: str self.verb: rsrc_verb self.data: dict + self.raw_headers: list[Any] + self.headers: dict[str, None | str | dict[str, None | str]] = {"host": None, "cookie": {}} self._saved_url_params: dict self.ReqParams: _T_RestRequestParams = type_request_params() self.url_stack: list[str] @@ -122,7 +128,7 @@ class RestRequest(Generic[_T_RestRequestParams]): self.incoming_cookie: dict[str, str] = incoming_cookie self.outgoing_cookie: dict[str, str] = outgoing_cookie self.user: ACL_target_user = ACL_target_user_Annonymous() - self.group: ACL_target_group = ACL_target_group_Annonymous() + self.groups: list[ACL_target_group] = [] self.result: Optional[str] = None # = or create a fresh one = @@ -151,6 +157,24 @@ class RestRequest(Generic[_T_RestRequestParams]): self._saved_url_stack = self.url_stack.copy() self.url_stack_index = 0 + def set_headers(self, headers: list[Any]) -> None: + self.raw_headers = headers + for elem in self.raw_headers: + if elem[0] == b"host": + self.headers["host"] = elem[1].decode("utf-8") + # elif elem[0] == b"user-agent": + # self.headers["user-agent"] = elem[1].decode("utf-8") + elif elem[0] == b"cookie": + self.headers["cookie"] = parse_dict_cookies(elem[1].decode("utf-8")) + + def get_cookie(self, key: str) -> str | None: + if key not in self.headers["cookie"]: + return None + return self.headers["cookie"][key] + + def get_host(self) -> str: + print(self.headers["host"]) + def set_result(self, result: str): self.result = result @@ -160,8 +184,8 @@ class RestRequest(Generic[_T_RestRequestParams]): def set_user(self, user: ACL_target_user): self.user: ACL_target_user = user - def set_group(self, group: ACL_target_group): - self.group: ACL_target_group = group + def add_group(self, group: ACL_target_group): + self.groups.append(group) def update_ReqParams(self, type_request_params: type[_T_RestRequestParams]): self.ReqParams = type_request_params(**self._saved_url_params) diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index cfad675..75525f5 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -15,6 +15,7 @@ from __future__ import annotations from abc import ABC from typing import ( + Any, cast, ClassVar, get_args, @@ -39,7 +40,6 @@ from .rest_ACL import ( ACL_target_user, ACL_target_group, ACL_target_user_Annonymous, - ACL_target_group_Annonymous, ACL_target_group_Any, ACL_rule, ) @@ -219,8 +219,8 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ ): if "plugin" in self.resource.json_schema_extra: plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] - if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase): - raise RuntimeError("Wrong plugin signature provided") + if not issubclass(plugin_resource, ResourcePlugin_RestResourceBase): + raise RuntimeError(f"Wrong plugin signature provided for {plugin_resource} : {type(plugin_resource)}") self.parent.annotation._plugins_[self.resource_name] = plugin_resource # print("ADD RESOURCE PLUGIN") @@ -246,7 +246,7 @@ def register_rest_rootpoint(klass: type[RestResourceBase]): class RestResourceBase(ABC, BaseModel, validate_assignment=True): - _resp_cookies: ClassVar[dict[str, str]] = dict() + # _resp_cookies: ClassVar[dict[str, str]] = {} _dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {} _dict_value_type_: ClassVar[dict[str, T_T_DictValues]] = {} _model_dump_excluded_: ClassVar[dict[str, bool]] = {} @@ -264,43 +264,45 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): ] ] = {} - def _check_acl(self, user: ACL_target_user, group: ACL_target_group, verb: rsrc_verb, field: str, is_self: bool = True): - # print(f"evaluate self ACLs rule: {self._ACL_record_}") + def _check_acl(self, user: str, groups: list[ACL_target_group], verb: rsrc_verb, field: str, is_self: bool = True): + print(f"evaluate self ACLs rule: {self._ACL_record_}") + print(f"user: {user}") + print(f"groups: {groups}") if is_self and verb is rsrc_verb.GET and self.model_fields[field].exclude is True: # print("ALLOWED (excluded field)") return for acl in self._ACL_record_[field]: - # print(f"evaluate ACL rule: {acl}") + print(f"evaluate ACL rule: {acl}") if verb in acl.verbs: if isinstance(acl.target, ACL_target_user): - if user == acl.target: + if user == acl.target.name: if acl.rule is ACL_rule.ALLOW: - # print("ALLOWED (user)") + print("ALLOWED (user)") return raise RuntimeError(f"Not allowed access detected: {field}") elif isinstance(acl.target, ACL_target_group): - if group == acl.target or acl.target == ACL_target_group_Any(): + if acl.target.name in groups or isinstance(acl.target, ACL_target_group_Any): if acl.rule is ACL_rule.ALLOW: - # print("ALLOWED (group)") + print("ALLOWED (group)") return raise RuntimeError(f"Not allowed access detected: {field}") else: raise RuntimeError(f"Wrong ACL target type: {field}") - # print("ALLOWED (Default)") + print("ALLOWED (Default)") def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None: """Check ACL on requested field access""" - self._check_acl(request.user, request.group, request.get_verb(), request.get_resource_origin(req_index), False) + self._check_acl(request.user, request.groups, request.get_verb(), request.get_resource_origin(req_index), False) def check_acl_self(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None: """Check ACL on requested field operation (involving checking sub-fields)""" if request.get_verb() is rsrc_verb.GET: for key in self.model_fields.keys(): - self._check_acl(request.user, request.group, rsrc_verb.GET, key) + self._check_acl(request.user, request.groups, rsrc_verb.GET, key) elif request.get_verb() is rsrc_verb.PUT: for key in new_data.keys(): if key in self.model_fields: - self._check_acl(request.user, request.group, rsrc_verb.PUT, key) + self._check_acl(request.user, request.groups, rsrc_verb.PUT, key) else: raise RuntimeError("Incompatible verb") @@ -324,21 +326,24 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): async def __call__(self, scope, receive, send): assert scope["type"] == "http" + method = scope["method"] + assert method in ["GET", "DELETE", "PUT", "POST"] + if b"content-type" in scope["headers"]: assert scope["headers"][b"content-type"] == b"application/json" - # import pprint + import pprint - # print("----REC HEADER ---") - # pprint.pprint(scope["headers"]) + print("----REC HEADER ---") + pprint.pprint(scope["headers"]) body = await self.read_body(receive) verb = rsrc_verb[scope["method"]] request: RestRequest = self.process_request( - scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8") + scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8"), scope["headers"] ) assert request != None @@ -373,12 +378,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): } ) + def _process_request_session(self, request: RestRequest) -> None: + pass + def process_request( self, url: str, verb: rsrc_verb = rsrc_verb.GET, data_json: Optional[str] = None, query_string: Optional[str] = None, + headers: Optional[list[Any]] = None, ) -> RestRequest: from .rest_resource_handler import ( ResourceHandler, @@ -389,11 +398,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if data_json: data = json.loads(data_json) + # creating the root handler ressource_handler: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) + # preparing request & session request: RestRequest = ressource_handler.get_request() - assert request != None + if headers is not None: + request.set_headers(headers) + self._process_request_session(request) + # processing the verb result = ressource_handler.process_verb() # print("OOO") diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 01ecc0b..5e99bd7 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -23,7 +23,6 @@ from .rest_ACL import ( ACL_target_user, ACL_target_group, ACL_target_user_Annonymous, - ACL_target_group_Annonymous, ACL_target_group_Any, ACL_rule, ) @@ -101,6 +100,7 @@ class ResourceHandler( self.next_handler: Optional[ResourceHandler] = None self.saved_url: list[str] = [] self.resource: _T_Resource = resource + self.root_resource: _T_Resource = resource if prev_handler is None else prev_handler.root_resource self.req: RestRequest if prev_handler is not None: self.prev_handler = prev_handler @@ -484,14 +484,6 @@ class ResourceHandler_RestResourceBase( if len(self.req.get_url_stack()) == 0: # destination reached if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET: raise RuntimeError(f"Not allowed READ access detected: {self.req.get_url_stack()}") - """ # not sure init_var has the expected behavior (read_only) - if self.resource.model_fields[self.req.get_resource_origin(0)].init_var is True and self.req.get_verb() in [ - rsrc_verb.POST, - rsrc_verb.PUT, - rsrc_verb.DELETE, - ]: - raise RuntimeError(f"Not allowed WRITE access detected: {self.req.get_url_stack()}") - """ def _handle_process_get(self, params) -> RestResourceBase: # print(f"{type(self).__name__}->_process_get()") @@ -504,11 +496,15 @@ class ResourceHandler_RestResourceBase( for key, attr in self.resource.model_fields.items(): if key in self.resource._plugins_: if isinstance(self.resource._plugins_[key], ResourcePlugin_field): - plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, self.resource._plugins_[key](self.req)) + plugin_field: ResourcePlugin_field = cast( + ResourcePlugin_field, self.resource._plugins_[key](self.req, self.root_resource) + ) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_field_get(value, params)) elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): - plugin_field: ResourcePlugin_field = cast(ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req)) + plugin_field: ResourcePlugin_field = cast( + ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource) + ) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_resource_get(value, params)) @@ -530,14 +526,14 @@ class ResourceHandler_RestResourceBase( if isinstance(self.resource._plugins_[key], ResourcePlugin_field): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.resource._plugins_[key](self.req), + self.resource._plugins_[key](self.req, self.root_resource), ) value = plugin_rsrc.handle_field_get(value, params) elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.resource._plugins_[key](self.req), + self.resource._plugins_[key](self.req, self.root_resource), ) value = plugin_rsrc.handle_resource_get(value, params) @@ -559,7 +555,9 @@ class ResourceHandler_RestResourceBase( for key, attr in _new_resrc.model_fields.items(): if key in _new_resrc._plugins_: if isinstance(_new_resrc._plugins_[key], ResourcePlugin_field): - plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, _new_resrc._plugins_[key](self.req)) + plugin_field: ResourcePlugin_field = cast( + ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.root_resource) + ) value = getattr(_new_resrc, key) setattr(_new_resrc, key, plugin_field.handle_field_put(value, params)) @@ -574,7 +572,7 @@ class ResourceHandler_RestResourceBase( if key in self.prev_handler.prev_handler.resource._plugins_: plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.prev_handler.prev_handler.resource._plugins_[key](self.req), + self.prev_handler.prev_handler.resource._plugins_[key](self.req, self.root_resource), ) _new_resrc = plugin_rsrc.handle_dict_elem_put(_new_resrc, params) # element is within a RestResourceBase @@ -583,7 +581,7 @@ class ResourceHandler_RestResourceBase( if key in self.prev_handler.resource._plugins_: plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.prev_handler.resource._plugins_[key](self.req), + self.prev_handler.resource._plugins_[key](self.req, self.root_resource), ) _new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params) @@ -634,7 +632,7 @@ class ResourceHandler_simple( if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, - self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req), + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource), ) return plugin_simple.handle_field_get(self.resource, params) @@ -655,7 +653,7 @@ class ResourceHandler_simple( # print("PLUGIN FOUND") plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, - self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req), + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource), ) # print(value) value = plugin_simple.handle_field_put(value, params) diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py index 9308aa6..72290f7 100644 --- a/src/pyrestresource/rest_resource_plugin.py +++ b/src/pyrestresource/rest_resource_plugin.py @@ -1,7 +1,7 @@ from __future__ import annotations -from typing import Optional, Protocol, runtime_checkable, TYPE_CHECKING -from abc import abstractmethod +from typing import Optional, Generic, TYPE_CHECKING +from abc import abstractmethod, ABC from .rest_types import ( _T_DictValues, @@ -12,6 +12,7 @@ from .rest_types import ( from .rest_request import RestRequest + if TYPE_CHECKING or True: from .rest_request_opt import ( RestRequestParams_GET, @@ -26,21 +27,33 @@ if TYPE_CHECKING or True: ) -class ResourcePlugin(Protocol): - def __init__(self, request: RestRequest) -> None: - self.request: RestRequest = request +class ResourcePlugin(ABC): + def __init__(self, request: RestRequest, root_resource: "RestResourceBase") -> None: + self.__request: RestRequest = request + self.__root_resource: RestRequest = root_resource - def set_resp_cookie(self, name: str, value: str): + def user_login(self, user_name: str, user_secret: str) -> str: + return self.__root_resource.user_login(user_name, user_secret, self.__request) + + """ + def get_ar_userlogin(self): + print("===========") + return self.__root_resource.get_ar_user_login() + """ + + def getr_req_cookie_value(self, key: str) -> Optional[str]: + return self.__request.incoming_cookie[key] + + def set_resp_cookie_value(self, key: str, value: str): # print("AAA") # print(name) # print(value) # print(self.cookies) # print(type(self.cookies)) - self.request.outgoing_cookie[name] = value + self.__request.outgoing_cookie[key] = value -@runtime_checkable -class ResourcePlugin_field(ResourcePlugin, Protocol[TV_SupportedRESTFields]): +class ResourcePlugin_field(ResourcePlugin, Generic[TV_SupportedRESTFields]): @abstractmethod def handle_field_get(self, resource: TV_SupportedRESTFields, params: RestRequestParams_GET) -> TV_SupportedRESTFields: ... @@ -60,8 +73,7 @@ class ResourcePlugin_field_default(ResourcePlugin_field[TV_SupportedRESTFields]) return resource -@runtime_checkable -class ResourcePlugin_RestResourceBase(ResourcePlugin, Protocol[TV_RestResourceBase]): +class ResourcePlugin_RestResourceBase(ResourcePlugin, Generic[TV_RestResourceBase]): @abstractmethod def handle_resource_get( self, @@ -97,8 +109,7 @@ class ResourcePlugin_RestResourceBase_default(ResourcePlugin_RestResourceBase[TV return resource -@runtime_checkable -class ResourcePlugin_dict(ResourcePlugin, Protocol[_T_DictKey, _T_DictValues]): +class ResourcePlugin_dict(ResourcePlugin, Generic[_T_DictKey, _T_DictValues]): @abstractmethod def handle_dict_get_keys( self, diff --git a/test/test_ACL.py b/test/test_ACL.py index 17859e3..5e69df8 100644 --- a/test/test_ACL.py +++ b/test/test_ACL.py @@ -59,7 +59,7 @@ def init_classes(): resource_with_secret_ACL: TestResource = Field( default=TestResource(), ACL=[ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY)] ) - resource2: TestResource2 = Field(TestResource2()) + resource_ro: TestResource2 = Field(TestResource2()) # this add the classes to globals to allow using them later on # => this is only for uinit-testing purpose and is not needed in real use @@ -77,21 +77,23 @@ class Test_RestAPI_ACL(unittest.TestCase): result = self.testapp.process_request("/", rsrc_verb.GET) self.assertEqual(result.get_result(), "{}") - result = self.testapp.process_request("/resource2", rsrc_verb.GET) + result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "3.2.1"}') - self.testapp.process_request("/resource2/version", rsrc_verb.PUT, '"6.6.6"') + self.testapp.process_request("/resource_ro/version", rsrc_verb.PUT, '"6.6.6"') - result = self.testapp.process_request("/resource2", rsrc_verb.GET) + result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') with self.assertRaises(RuntimeError): # TODO: custom exception - self.testapp.process_request("/resource2/version_ro", rsrc_verb.PUT, '"6.6.6"') + self.testapp.process_request("/resource_ro/version_ro", rsrc_verb.PUT, '"6.6.6"') + self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3") with self.assertRaises(RuntimeError): # TODO: custom exception - self.testapp.process_request("/resource2", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}') + self.testapp.process_request("/resource_ro", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}') + self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3") - result = self.testapp.process_request("/resource2", rsrc_verb.GET) + result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') def test_subresource(self): diff --git a/test/test_rest_login.py b/test/test_rest_login.py index ff71667..96e451a 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -3,7 +3,7 @@ import unittest from unittest.mock import patch from os import chdir from pathlib import Path -from typing import Optional, Annotated +from typing import Optional, Annotated, ClassVar from pydantic import Field from uuid import UUID, uuid4 from time import time, sleep @@ -12,16 +12,17 @@ import socket import requests from contextlib import closing from multiprocessing import Process -from secrets import token_hex print(__name__) print(__package__) -from pydantic import BaseModel from src.pyrestresource import ( - register_rest_rootpoint, + ACL_target_user, + UserLogin, RestResourceBase, + RestResourceBaseLogin, + register_rest_rootpoint, rsrc_verb, RestRequestParams_GET, RestRequestParams_POST, @@ -42,58 +43,25 @@ chdir(testdir_path.parent.resolve()) # to allow mock-ing, all the tested classes are in a function def init_classes(): - class UserLogin(BaseModel): - username: str - secret: str - token: Optional[str] = None + user_CHACHA = UserLogin(username="chacha", secret="123456") - class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): - ar_UserLogin: list[UserLogin] = [UserLogin(username="chacha", secret="123456")] - - def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login: - print("hook GET") - print(resource) - print(params) - return resource - - def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login: - print("hook PUT") - - print(resource.username) - print(resource.secret) - - for _UserLogin in self.ar_UserLogin: - if _UserLogin.username == resource.username and _UserLogin.secret == resource.secret: - print("user connected") - _UserLogin.token = token_hex(16) - self.set_resp_cookie("test", _UserLogin.token) - print(f"generated token: {_UserLogin.token}") - return resource - print("login NOT found") - # print(resource) - # print(resource.username) - # print(resource.secret) - # print(params) - return resource - - class Login(RestResourceBase): - username: Optional[str] = Field(None) - secret: Optional[str] = Field( - None, - exclude=True, + class TestResourceACL(RestResourceBase): + test_field: Optional[str] = Field( + "ORIGIN_VALUE", ACL=[ - ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), - ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user(name="chacha"), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ], ) @register_rest_rootpoint - class RootApp(RestResourceBase): - login: Login = Field(default=Login(), plugin=ResourcePlugin_Login) + class RootApp(RestResourceBaseLogin): + _ar_user_login: ClassVar[list[UserLogin]] = [user_CHACHA] + test_resource: TestResourceACL = TestResourceACL() # this add the classes to globals to allow using them later on # => this is only for uinit-testing purpose and is not needed in real use - globals()[Login.__name__] = Login + globals()[TestResourceACL.__name__] = TestResourceACL globals()[RootApp.__name__] = RootApp @@ -116,6 +84,61 @@ class Test_RestAPI_LOGIN(unittest.TestCase): init_classes() self.testapp = RootApp() + def test_access(self): + ip, port = find_free_port() + print(f"ip1={ip}") + print(f"port1={port}") + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + try: + # before modification read + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json='"TEST SET VALUE"') + self.assertEqual(response.status_code, 500) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "chacha", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # authenticated write + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + finally: + proc.terminate() + s.close() + def test_login(self): result = self.testapp.process_request("/login", rsrc_verb.GET) print("*****************") @@ -172,6 +195,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): json={"username": "chacha", "secret": "123456"}, ) print(response) + print("??????") print(response.headers) self.assertEqual(response.status_code, 201) -- 2.47.3 From cffa209c9ab65b69e4fbde9efda947cc17690a9f Mon Sep 17 00:00:00 2001 From: cclecle Date: Sun, 5 Nov 2023 17:44:15 +0000 Subject: [PATCH 09/20] finish 1st login version + clean code + fix unittest regressions --- src/pyrestresource/__init__.py | 9 +- src/pyrestresource/helpers.py | 1 + src/pyrestresource/rest_ACL.py | 8 + src/pyrestresource/rest_login.py | 25 +- src/pyrestresource/rest_request.py | 19 +- src/pyrestresource/rest_request_opt.py | 5 +- src/pyrestresource/rest_resource.py | 247 ++--------------- src/pyrestresource/rest_resource_handler.py | 43 +-- .../rest_resource_handler_walker.py | 12 +- src/pyrestresource/rest_resource_plugin.py | 25 +- src/pyrestresource/rest_resource_rootpoint.py | 169 ++++++++++++ src/pyrestresource/rest_resource_walker.py | 45 ++-- src/pyrestresource/rest_types.py | 12 +- test/test_rest_login.py | 251 ++++++++++++------ test/test_rest_resource_plugins.py | 44 +-- test/test_rest_webserver.py | 11 +- 16 files changed, 463 insertions(+), 463 deletions(-) create mode 100644 src/pyrestresource/rest_resource_rootpoint.py diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index 3722afd..d129cc3 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -19,10 +19,8 @@ from typing import TYPE_CHECKING from .__metadata__ import __version__, __Summuary__, __Name__ -from .rest_resource import ( - register_rest_rootpoint, - RestResourceBase, -) +from .rest_resource import RestResourceBase +from .rest_resource_rootpoint import register_rest_rootpoint from .rest_types import rsrc_verb, T_SupportedRESTFields @@ -34,6 +32,7 @@ if TYPE_CHECKING: T_T_DictKey, T_DictValues, T_T_DictValues, + RestResourceException, ) from .rest_request_opt import ( @@ -52,6 +51,8 @@ from .rest_resource_plugin import ( ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, ResourcePlugin_dict_default, + RestResourcePluginException, + RestResourcePluginException_InvalidPluginSignature, ) from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule from .rest_login import RestResourceBaseLogin, UserLogin diff --git a/src/pyrestresource/helpers.py b/src/pyrestresource/helpers.py index a08067b..e8c13d1 100644 --- a/src/pyrestresource/helpers.py +++ b/src/pyrestresource/helpers.py @@ -1,6 +1,7 @@ # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring from __future__ import annotations + from uuid import UUID import json diff --git a/src/pyrestresource/rest_ACL.py b/src/pyrestresource/rest_ACL.py index 3ebdbaa..e7158e4 100644 --- a/src/pyrestresource/rest_ACL.py +++ b/src/pyrestresource/rest_ACL.py @@ -1,10 +1,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING from pydantic import BaseModel from enum import Enum, auto from .rest_types import rsrc_verb +if TYPE_CHECKING is True: + from .rest_login import UserLogin + class ACL_target(BaseModel): pass @@ -13,6 +17,10 @@ class ACL_target(BaseModel): class ACL_target_user(ACL_target): name: str + @classmethod + def from_user_login(cls, user_login: UserLogin) -> ACL_target_user: + return cls(name=user_login.username) + class ACL_target_user_Annonymous(ACL_target): name: str = "__ANNONYMOUS__" diff --git a/src/pyrestresource/rest_login.py b/src/pyrestresource/rest_login.py index 0fa0c2e..707e421 100644 --- a/src/pyrestresource/rest_login.py +++ b/src/pyrestresource/rest_login.py @@ -12,21 +12,19 @@ """CLI interface module""" from __future__ import annotations - from typing import Optional, ClassVar, TYPE_CHECKING + from secrets import token_hex, compare_digest from datetime import datetime - from pydantic import BaseModel, Field from .rest_types import rsrc_verb from .rest_resource import RestResourceBase +from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule, ACL_target_user +from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default -from .rest_request import RestRequest, RestRequestParams_GET -from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule - -if TYPE_CHECKING or True: - from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default +if TYPE_CHECKING is True: + from .rest_request import RestRequest, RestRequestParams_GET class UserLogin(BaseModel): @@ -44,20 +42,11 @@ class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): ar_UserLogin: list[UserLogin] = [] def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login: - print("hook GET") - print(resource) - print(params) - return resource + return Login(username=self.get_user_login()) def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login: - print("hook PUT") - # print(self.get_ar_userlogin()) - print(resource.username) - print(resource.secret) - token = self.user_login(resource.username, resource.secret) self.set_resp_cookie_value("Authorization", f"Bearer {token}") - return resource @@ -93,7 +82,7 @@ class RestResourceBaseLogin(RestResourceBase): del self._ar_user_session[auth_cookie] raise RuntimeError("session timeout ! (session reseted)") - request.set_user(self._ar_user_session[auth_cookie].user_login.username) + request.set_user(ACL_target_user(name=self._ar_user_session[auth_cookie].user_login.username)) return print("Invalid session") diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index c5d08cd..0d37b7d 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -4,19 +4,16 @@ from __future__ import annotations from typing import ( Any, - Optional, Generic, + TYPE_CHECKING, ) + from re import sub from urllib.parse import urlparse, parse_qs -from http.cookies import SimpleCookie - from pydantic import BaseModel, Field - from typeguard import check_type -from .rest_types import rsrc_verb, T_SupportedRESTFields, T_AllSupportedFields - +from .rest_types import rsrc_verb, T_AllSupportedFields from .rest_request_opt import ( RestRequestParams_POST, RestRequestParams_DELETE, @@ -28,10 +25,13 @@ from .rest_request_opt import ( _T_RestRequestParams_GET, _T_RestRequestParams_PUT, ) - from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group from .helpers import parse_dict_cookies +if TYPE_CHECKING is True: + from typing import Optional + from .rest_types import T_SupportedRESTFields + class RequestFactory( Generic[ @@ -182,7 +182,10 @@ class RestRequest(Generic[_T_RestRequestParams]): return self.result def set_user(self, user: ACL_target_user): - self.user: ACL_target_user = user + self.user = user + + def get_user(self): + return self.user def add_group(self, group: ACL_target_group): self.groups.append(group) diff --git a/src/pyrestresource/rest_request_opt.py b/src/pyrestresource/rest_request_opt.py index 10cc360..599bdbe 100644 --- a/src/pyrestresource/rest_request_opt.py +++ b/src/pyrestresource/rest_request_opt.py @@ -1,14 +1,17 @@ # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring from __future__ import annotations +from typing import Generic, Optional, TypeVar, TYPE_CHECKING -from typing import Optional, Generic, TypeVar from pydantic import BaseModel, Extra from .rest_types import ( _T_DictKey, ) +if TYPE_CHECKING is True: + pass + class RestRequestParams(BaseModel, extra=Extra.allow): pass diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index 75525f5..bd42c32 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -1,250 +1,35 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# pyrestresource(c) by chacha -# -# pyrestresource is licensed under a -# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. -# -# You should have received a copy of the license along with this -# work. If not, see . - -# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring - -"""CLI interface module""" from __future__ import annotations - -from abc import ABC from typing import ( Any, - cast, ClassVar, - get_args, - get_origin, Optional, TYPE_CHECKING, ) + +from abc import ABC import json -from pydantic.fields import FieldInfo, Field from pydantic import BaseModel from .helpers import _JSONEncoder from .rest_types import rsrc_verb, _T_SupportedRESTFields -from .rest_resource_plugin import ( - ResourcePlugin_field, - ResourcePlugin_RestResourceBase, - ResourcePlugin_dict, -) from .rest_ACL import ( ACL_record, ACL_target_user, ACL_target_group, - ACL_target_user_Annonymous, ACL_target_group_Any, ACL_rule, ) - -from .rest_resource_walker import ( - RestResourceWalkerFutureResult, - RestResourceWalker_Root, - RestResourceWalker_Sub_T_Dict, - RestResourceWalker_Sub_RestFields, - RestResourceWalker_Sub_RestResourceBase, -) - from .rest_request import RestRequest -if TYPE_CHECKING: +if TYPE_CHECKING is True: from .rest_types import ( - T_ListIndex, - T_ListSize, - T_DictKey, T_T_DictKey, - T_DictValues, T_T_DictValues, - T_SupportedRESTFields, ) -class RestResourceWalkerFutureResult_RestResourceBase_tree_exclude(RestResourceWalkerFutureResult[dict]): - def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: - res = {} - res[self.source.resource_name] = dict() - for subres in result: - key = next(iter(subres)) - if ( - key in self.source.annotation._model_dump_excluded_ # pylint: disable=protected-access - and self.source.annotation._model_dump_excluded_[key] is True # pylint: disable=protected-access - ): - res[self.source.resource_name] = res[self.source.resource_name] | {key: True} - else: - res[self.source.resource_name] = res[self.source.resource_name] | subres - return res - - -class RestResourceWalkerFutureResult_Dict_tree_exclude(RestResourceWalkerFutureResult[dict]): - def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: - res = {} - for subres in result: - res = res | subres - return res - - -class RestResourceWalker_Sub_T_Dict__tree_exclude(RestResourceWalker_Sub_T_Dict): - cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_Dict_tree_exclude - - -class RestResourceWalker_Sub_RestResourceBase__tree_exclude(RestResourceWalker_Sub_RestResourceBase): - cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestResourceBase_tree_exclude - - -class RestResourceWalker_Root__tree_exclude(RestResourceWalker_Root): - cls_RestResourceWalker_Sub = [ - RestResourceWalker_Sub_T_Dict__tree_exclude, - RestResourceWalker_Sub_RestFields, - RestResourceWalker_Sub_RestResourceBase__tree_exclude, - ] - - -class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): - def process(self) -> None: - datatype = get_args(self.annotation) - - # checking compatibility - if not get_origin(datatype[1]) is None: - raise RuntimeError("complex dict types are not supported (should create a RestResourceBase container)") - if not datatype[0] in _T_SupportedRESTFields: - raise RuntimeError(f"Unsupported Dict Field value type in class (key)") - - # preprocessing types / structure - if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): - self.parent.annotation._dict_key_type_[self.resource_name] = datatype[0] # pylint: disable=protected-access - self.parent.annotation._dict_value_type_[self.resource_name] = datatype[1] # pylint: disable=protected-access - self.parent.annotation._model_dump_excluded_[self.resource_name] = True # pylint: disable=protected-access - - self.resource.exclude = True - self.parent.resource.model_rebuild(force=True) - - self.parent.annotation._ACL_record_[self.resource_name] = [] - - if ( - isinstance(self.resource, FieldInfo) - and self.resource.json_schema_extra is not None - and type(self.resource.json_schema_extra) is dict - ): - if "plugin" in self.resource.json_schema_extra: - plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"] - if not isinstance(plugin_dict, ResourcePlugin_dict): - raise RuntimeError("Wrong plugin signature provided") - self.parent.annotation._plugins_[self.resource_name] = plugin_dict - # print("ADD DICT PLUGIN") - - if "ACL" in self.resource.json_schema_extra: - if isinstance(self.resource.json_schema_extra["ACL"], list): - # print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") - self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] - else: - raise RuntimeError("ACL must be a list()") - - else: - raise RuntimeError("dict must be contained in a RestResourceBase") - - -class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFields): - def process(self) -> None: - if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): - import pprint - - # print("1aaaaaaaaaa") - # pprint.pprint(self.resource.json_schema_extra) - # pprint.pprint(self.annotation) - # pprint.pprint(self.resource.exclude) - - self.parent.annotation._ACL_record_[self.resource_name] = [] - - if ( - isinstance(self.resource, FieldInfo) - and self.resource.json_schema_extra is not None - and type(self.resource.json_schema_extra) is dict - ): - # print("aaaaaaaaaa") - - if "primary_key" in self.resource.json_schema_extra and self.resource.json_schema_extra["primary_key"] is True: - if self.parent.annotation._primary_key_ is not None: - raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") - self.parent.annotation._primary_key_ = self.resource_name - self.parent.annotation._ACL_record_[self.resource_name] = [ - ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY) - ] - - if "plugin" in self.resource.json_schema_extra: - plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"] - if not isinstance(plugin_field, ResourcePlugin_field): - raise RuntimeError("Wrong plugin signature provided") - self.parent.annotation._plugins_[self.resource_name] = plugin_field - # print("ADD FIELD PLUGIN") - - if "ACL" in self.resource.json_schema_extra: - if isinstance(self.resource.json_schema_extra["ACL"], list): - # print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") - self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] - else: - raise RuntimeError("ACL must be a list()") - - else: - raise RuntimeError("fields must be contained in a RestResourceBase") - - -class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_RestResourceBase): - def process(self) -> None: - setattr(self.annotation, "_dict_key_type_", {}) - setattr(self.annotation, "_dict_value_type_", {}) - setattr(self.annotation, "_model_dump_excluded_", {}) - setattr(self.annotation, "_primary_key_", None) - setattr(self.annotation, "_plugins_", {}) - setattr(self.annotation, "_ACL_record_", {}) - - # preprocessing types / structure - if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): - self.parent.annotation._model_dump_excluded_[self.resource_name] = True - self.resource.exclude = True - self.parent.resource.model_rebuild(force=True) - self.parent.annotation._ACL_record_[self.resource_name] = [] - - if ( - isinstance(self.resource, FieldInfo) - and self.resource.json_schema_extra is not None - and type(self.resource.json_schema_extra) is dict - ): - if "plugin" in self.resource.json_schema_extra: - plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] - if not issubclass(plugin_resource, ResourcePlugin_RestResourceBase): - raise RuntimeError(f"Wrong plugin signature provided for {plugin_resource} : {type(plugin_resource)}") - self.parent.annotation._plugins_[self.resource_name] = plugin_resource - # print("ADD RESOURCE PLUGIN") - - if "ACL" in self.resource.json_schema_extra: - if isinstance(self.resource.json_schema_extra["ACL"], list): - # print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") - self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] - else: - raise RuntimeError("ACL must be a list()") - - -class RestResourceWalker_Root__tree_init(RestResourceWalker_Root): - cls_RestResourceWalker_Sub = [ - RestResourceWalker_Sub_T_Dict__tree_init, - RestResourceWalker_Sub_RestFields__tree_init, - RestResourceWalker_Sub_RestResourceBase__tree_init, - ] - - -def register_rest_rootpoint(klass: type[RestResourceBase]): - RestResourceWalker_Root__tree_init(klass).process() - return klass - - class RestResourceBase(ABC, BaseModel, validate_assignment=True): # _resp_cookies: ClassVar[dict[str, str]] = {} _dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {} @@ -264,31 +49,31 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): ] ] = {} - def _check_acl(self, user: str, groups: list[ACL_target_group], verb: rsrc_verb, field: str, is_self: bool = True): - print(f"evaluate self ACLs rule: {self._ACL_record_}") - print(f"user: {user}") - print(f"groups: {groups}") + def _check_acl(self, user: ACL_target_user, groups: list[ACL_target_group], verb: rsrc_verb, field: str, is_self: bool = True): + # print(f"evaluate self ACLs rule: {self._ACL_record_}") + # print(f"user: {user}") + # print(f"groups: {groups}") if is_self and verb is rsrc_verb.GET and self.model_fields[field].exclude is True: # print("ALLOWED (excluded field)") return for acl in self._ACL_record_[field]: - print(f"evaluate ACL rule: {acl}") + # print(f"evaluate ACL rule: {acl}") if verb in acl.verbs: if isinstance(acl.target, ACL_target_user): - if user == acl.target.name: + if user == acl.target: if acl.rule is ACL_rule.ALLOW: - print("ALLOWED (user)") + # print("ALLOWED (user)") return raise RuntimeError(f"Not allowed access detected: {field}") elif isinstance(acl.target, ACL_target_group): - if acl.target.name in groups or isinstance(acl.target, ACL_target_group_Any): + if isinstance(acl.target, ACL_target_group_Any) or any(_ for _ in groups if _.name == acl.target.name): if acl.rule is ACL_rule.ALLOW: - print("ALLOWED (group)") + # print("ALLOWED (group)") return raise RuntimeError(f"Not allowed access detected: {field}") else: raise RuntimeError(f"Wrong ACL target type: {field}") - print("ALLOWED (Default)") + # print("ALLOWED (Default)") def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None: """Check ACL on requested field access""" @@ -334,10 +119,10 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if b"content-type" in scope["headers"]: assert scope["headers"][b"content-type"] == b"application/json" - import pprint + # import pprint - print("----REC HEADER ---") - pprint.pprint(scope["headers"]) + # print("----REC HEADER ---") + # pprint.pprint(scope["headers"]) body = await self.read_body(receive) verb = rsrc_verb[scope["method"]] diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 5e99bd7..77cfd42 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -1,32 +1,22 @@ from __future__ import annotations -import abc from typing import Optional, cast, TypeVar, Generic, Self, TYPE_CHECKING +import abc + from .rest_types import ( rsrc_verb, T_SupportedRESTFields, T_DictKey, _T_SupportedRESTFields, T_Dict, - T_T_DictValues, T_DictValues, ) from .rest_resource import RestResourceBase -from .rest_request import RequestFactory, RestRequest - +from .rest_request import RequestFactory from .rest_resource_plugin import ( ResourcePlugin_field, ResourcePlugin_RestResourceBase, ) - -from .rest_ACL import ( - ACL_target_user, - ACL_target_group, - ACL_target_user_Annonymous, - ACL_target_group_Any, - ACL_rule, -) - from .rest_request_opt import ( RestRequestParams_POST, RestRequestParams_DELETE, @@ -43,16 +33,9 @@ from .rest_request_opt import ( _T_RestRequestParams_PUT, ) -from .rest_resource_handler_walker import RestResourceWalker_Root__handler - -if TYPE_CHECKING: - from .rest_types import ( - T_ListIndex, - T_ListSize, - T_T_DictKey, - T_FieldValue, - ) - +if TYPE_CHECKING is True: + from .rest_types import T_T_DictKey, T_T_DictValues + from .rest_request import RestRequest _T_Resource = TypeVar("_T_Resource", T_DictValues, T_Dict, T_SupportedRESTFields, RestResourceBase) @@ -311,7 +294,7 @@ class ResourceHandler_dict( dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] - _dict: dict[T_DictKey, "T_DictValues"] = cast(dict[T_DictKey, "T_DictValues"], self.resource) + _dict: dict[T_DictKey, T_DictValues] = cast(dict[T_DictKey, T_DictValues], self.resource) if params.API_key is not None: del _dict[dict_key_type(params.API_key)] @@ -335,7 +318,7 @@ class ResourceHandler_dict( _obj = dict_value_type(**self.req.get_data()) - _dict: dict[T_DictKey, "T_DictValues"] = cast(dict[T_DictKey, "T_DictValues"], self.resource) + _dict: dict[T_DictKey, T_DictValues] = cast(dict[T_DictKey, T_DictValues], self.resource) # 1st try/ using request param provided dict API_key if params.API_key is not None: @@ -495,13 +478,13 @@ class ResourceHandler_RestResourceBase( self.resource.check_acl_self(self.req, None) for key, attr in self.resource.model_fields.items(): if key in self.resource._plugins_: - if isinstance(self.resource._plugins_[key], ResourcePlugin_field): + if issubclass(self.resource._plugins_[key], ResourcePlugin_field): plugin_field: ResourcePlugin_field = cast( ResourcePlugin_field, self.resource._plugins_[key](self.req, self.root_resource) ) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_field_get(value, params)) - elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): + elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_field: ResourcePlugin_field = cast( ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource) ) @@ -523,14 +506,14 @@ class ResourceHandler_RestResourceBase( key = self.req.get_resource_origin(0) if key in self.resource._plugins_: - if isinstance(self.resource._plugins_[key], ResourcePlugin_field): + if issubclass(self.resource._plugins_[key], ResourcePlugin_field): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource), ) value = plugin_rsrc.handle_field_get(value, params) - elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): + elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource), @@ -554,7 +537,7 @@ class ResourceHandler_RestResourceBase( if isinstance(_new_resrc, RestResourceBase): for key, attr in _new_resrc.model_fields.items(): if key in _new_resrc._plugins_: - if isinstance(_new_resrc._plugins_[key], ResourcePlugin_field): + if issubclass(_new_resrc._plugins_[key], ResourcePlugin_field): plugin_field: ResourcePlugin_field = cast( ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.root_resource) ) diff --git a/src/pyrestresource/rest_resource_handler_walker.py b/src/pyrestresource/rest_resource_handler_walker.py index e23ec6f..c0f06ac 100644 --- a/src/pyrestresource/rest_resource_handler_walker.py +++ b/src/pyrestresource/rest_resource_handler_walker.py @@ -12,14 +12,7 @@ """CLI interface module""" from __future__ import annotations - -from typing import ( - ClassVar, - get_args, - get_origin, - Optional, - TYPE_CHECKING, -) +from typing import TYPE_CHECKING from .rest_resource_walker import ( RestResourceWalkerFutureResult, @@ -29,6 +22,9 @@ from .rest_resource_walker import ( RestResourceWalker_Sub_RestResourceBase, ) +if TYPE_CHECKING is True: + from typing import Optional + class RestResourceWalkerFutureResult_RestResourceBase_handler(RestResourceWalkerFutureResult[dict]): def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py index 72290f7..4e7bcea 100644 --- a/src/pyrestresource/rest_resource_plugin.py +++ b/src/pyrestresource/rest_resource_plugin.py @@ -1,6 +1,6 @@ from __future__ import annotations - from typing import Optional, Generic, TYPE_CHECKING + from abc import abstractmethod, ABC from .rest_types import ( @@ -8,12 +8,12 @@ from .rest_types import ( _T_DictKey, TV_SupportedRESTFields, TV_RestResourceBase, + RestResourceException, ) - from .rest_request import RestRequest - -if TYPE_CHECKING or True: +if TYPE_CHECKING is True: + from .rest_resource import RestResourceBase from .rest_request_opt import ( RestRequestParams_GET, RestRequestParams_PUT, @@ -27,19 +27,24 @@ if TYPE_CHECKING or True: ) +class RestResourcePluginException(RestResourceException): + pass + + +class RestResourcePluginException_InvalidPluginSignature(RestResourcePluginException): + pass + + class ResourcePlugin(ABC): - def __init__(self, request: RestRequest, root_resource: "RestResourceBase") -> None: + def __init__(self, request: RestRequest, root_resource: RestResourceBase) -> None: self.__request: RestRequest = request self.__root_resource: RestRequest = root_resource def user_login(self, user_name: str, user_secret: str) -> str: return self.__root_resource.user_login(user_name, user_secret, self.__request) - """ - def get_ar_userlogin(self): - print("===========") - return self.__root_resource.get_ar_user_login() - """ + def get_user_login(self) -> str: + return self.__request.get_user().name def getr_req_cookie_value(self, key: str) -> Optional[str]: return self.__request.incoming_cookie[key] diff --git a/src/pyrestresource/rest_resource_rootpoint.py b/src/pyrestresource/rest_resource_rootpoint.py new file mode 100644 index 0000000..a13e521 --- /dev/null +++ b/src/pyrestresource/rest_resource_rootpoint.py @@ -0,0 +1,169 @@ +from __future__ import annotations +from typing import ( + get_args, + get_origin, + TYPE_CHECKING, +) + +from pydantic.fields import FieldInfo + +from .rest_resource import RestResourceBase +from .rest_resource_plugin import ( + ResourcePlugin_field, + ResourcePlugin_RestResourceBase, + ResourcePlugin_dict, + RestResourcePluginException_InvalidPluginSignature, +) +from .rest_resource_walker import ( + RestResourceWalker_Root, + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, +) +from .rest_types import rsrc_verb, _T_SupportedRESTFields +from .rest_ACL import ( + ACL_record, + ACL_target_group_Any, + ACL_rule, +) + +if TYPE_CHECKING is True: + pass + + +class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): + def process(self) -> None: + datatype = get_args(self.annotation) + + # checking compatibility + if not get_origin(datatype[1]) is None: + raise RuntimeError("complex dict types are not supported (should create a RestResourceBase container)") + if not datatype[0] in _T_SupportedRESTFields: + raise RuntimeError(f"Unsupported Dict Field value type in class (key)") + + # preprocessing types / structure + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + self.parent.annotation._dict_key_type_[self.resource_name] = datatype[0] # pylint: disable=protected-access + self.parent.annotation._dict_value_type_[self.resource_name] = datatype[1] # pylint: disable=protected-access + self.parent.annotation._model_dump_excluded_[self.resource_name] = True # pylint: disable=protected-access + + self.resource.exclude = True + self.parent.resource.model_rebuild(force=True) + + self.parent.annotation._ACL_record_[self.resource_name] = [] + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + ): + if "plugin" in self.resource.json_schema_extra: + plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"] + if not issubclass(plugin_dict, ResourcePlugin_dict): + raise RestResourcePluginException_InvalidPluginSignature() + self.parent.annotation._plugins_[self.resource_name] = plugin_dict + # print("ADD DICT PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + # print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") + + else: + raise RuntimeError("dict must be contained in a RestResourceBase") + + +class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFields): + def process(self) -> None: + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + # import pprint + + # print("1aaaaaaaaaa") + # pprint.pprint(self.resource.json_schema_extra) + # pprint.pprint(self.annotation) + # pprint.pprint(self.resource.exclude) + + self.parent.annotation._ACL_record_[self.resource_name] = [] + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + ): + # print("aaaaaaaaaa") + + if "primary_key" in self.resource.json_schema_extra and self.resource.json_schema_extra["primary_key"] is True: + if self.parent.annotation._primary_key_ is not None: + raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") + self.parent.annotation._primary_key_ = self.resource_name + self.parent.annotation._ACL_record_[self.resource_name] = [ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY) + ] + + if "plugin" in self.resource.json_schema_extra: + plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"] + if not issubclass(plugin_field, ResourcePlugin_field): + raise RestResourcePluginException_InvalidPluginSignature() + self.parent.annotation._plugins_[self.resource_name] = plugin_field + # print("ADD FIELD PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + # print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") + + else: + raise RuntimeError("fields must be contained in a RestResourceBase") + + +class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_RestResourceBase): + def process(self) -> None: + setattr(self.annotation, "_dict_key_type_", {}) + setattr(self.annotation, "_dict_value_type_", {}) + setattr(self.annotation, "_model_dump_excluded_", {}) + setattr(self.annotation, "_primary_key_", None) + setattr(self.annotation, "_plugins_", {}) + setattr(self.annotation, "_ACL_record_", {}) + + # preprocessing types / structure + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + self.parent.annotation._model_dump_excluded_[self.resource_name] = True + self.resource.exclude = True + self.parent.resource.model_rebuild(force=True) + self.parent.annotation._ACL_record_[self.resource_name] = [] + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + ): + if "plugin" in self.resource.json_schema_extra: + plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] + if not issubclass(plugin_resource, ResourcePlugin_RestResourceBase): + raise RestResourcePluginException_InvalidPluginSignature() + self.parent.annotation._plugins_[self.resource_name] = plugin_resource + # print("ADD RESOURCE PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + # print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") + + +class RestResourceWalker_Root__tree_init(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict__tree_init, + RestResourceWalker_Sub_RestFields__tree_init, + RestResourceWalker_Sub_RestResourceBase__tree_init, + ] + + +def register_rest_rootpoint(klass: type[RestResourceBase]): + RestResourceWalker_Root__tree_init(klass).process() + return klass diff --git a/src/pyrestresource/rest_resource_walker.py b/src/pyrestresource/rest_resource_walker.py index 808eb5d..a72a834 100644 --- a/src/pyrestresource/rest_resource_walker.py +++ b/src/pyrestresource/rest_resource_walker.py @@ -1,26 +1,23 @@ from __future__ import annotations - from typing import ( cast, - Any, - Optional, Union, get_args, get_origin, TypeVar, + Type, Generic, TYPE_CHECKING, ) -from typing import Type -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod from pydantic.fields import FieldInfo from .rest_types import _T_SupportedRESTFields +from .rest_resource import RestResourceBase - -if TYPE_CHECKING: - from .rest_resource import RestResourceBase +if TYPE_CHECKING is True: + from typing import Any, Optional TV_RestResourceWalkerFutureResult = TypeVar("TV_RestResourceWalkerFutureResult") @@ -42,7 +39,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): @classmethod @abstractmethod - def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: """implementation interface to Factory. The factory will call this specialized method on each implementation to find a supported one. """ @@ -53,7 +50,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): self, subs: list[type[RestResourceWalker_Sub]], resource_name: str, - resource: FieldInfo | Type["RestResourceBase"], + resource: FieldInfo | Type[RestResourceBase], parent: Optional[RestResourceWalker_Sub] = None, argument: Optional[any] = None, ) -> Optional[RestResourceWalker_Sub]: @@ -68,15 +65,15 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): def __init__( self, resource_name: str, - resource: FieldInfo | Type["RestResourceBase"], + resource: FieldInfo | Type[RestResourceBase], parent: Optional[RestResourceWalker_Sub] = None, - annotation: Optional[type["RestResourceBase"]] = None, + annotation: Optional[type[RestResourceBase]] = None, _optional: Optional[bool] = None, argument: Optional[any] = None, ): self.argument: any = argument self.resource_name: str = resource_name - self.resource: FieldInfo | Type["RestResourceBase"] = resource + self.resource: FieldInfo | Type[RestResourceBase] = resource self.parent: Optional[RestResourceWalker_Sub] = parent self.future_results_subs: Optional[list[RestResourceWalkerFutureResult[TV_RestResourceWalkerFutureResult]]] = None @@ -85,7 +82,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): self.future_results_subs = [] self.future_result = self.cls_RestResourceWalkerFutureResult(self) - self.annotation: type["RestResourceBase"] + self.annotation: type[RestResourceBase] self.optional: bool if annotation is None or _optional is None: self.annotation, self.optional = self.ProcessAnnotation(resource) @@ -151,9 +148,9 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): @staticmethod def ProcessAnnotation( - resource: FieldInfo | Type["RestResourceBase"], + resource: FieldInfo | Type[RestResourceBase], ) -> tuple[type[Any], bool]: - from .rest_resource import RestResourceBase + # from .rest_resource import RestResourceBase _anno: Type[Any] @@ -186,7 +183,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): class RestResourceWalker_Sub_T_Dict(RestResourceWalker_Sub): @classmethod - def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: _anno, _optional = cls.ProcessAnnotation(resource) _type_resource = get_origin(_anno) return (_type_resource is dict), _anno, _optional @@ -202,7 +199,7 @@ class RestResourceWalker_Sub_T_Dict(RestResourceWalker_Sub): class RestResourceWalker_Sub_RestFields(RestResourceWalker_Sub): @classmethod - def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: _anno, _optional = cls.ProcessAnnotation(resource) return (_anno in _T_SupportedRESTFields), _anno, _optional @@ -212,9 +209,7 @@ class RestResourceWalker_Sub_RestFields(RestResourceWalker_Sub): class RestResourceWalker_Sub_RestResourceBase(RestResourceWalker_Sub): @classmethod - def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: - from .rest_resource import RestResourceBase - + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: _anno, _optional = cls.ProcessAnnotation(resource) return ( ((get_origin(_anno) is None) and issubclass(_anno, RestResourceBase)), @@ -236,11 +231,9 @@ class RestResourceWalker_Root: RestResourceWalker_Sub_RestResourceBase, ] - def __init__(self, resource: "RestResourceBase" | Type["RestResourceBase"]) -> None: + def __init__(self, resource: RestResourceBase | Type[RestResourceBase]) -> None: self.subwalker_argument: any = None - from .rest_resource import RestResourceBase - - self.resource: Type["RestResourceBase"] + self.resource: Type[RestResourceBase] if isinstance(resource, RestResourceBase): self.resource = type(resource) else: @@ -256,7 +249,7 @@ class RestResourceWalker_Root: if sub_walker_initial is not None: sub_walker_initial.process() sub_walker_initial.get_future() - resource_list: list[tuple[str, FieldInfo | Type["RestResourceBase"], RestResourceWalker_Sub]] = [ + resource_list: list[tuple[str, FieldInfo | Type[RestResourceBase], RestResourceWalker_Sub]] = [ (subresource_name, subresource, sub_walker_initial) for subresource_name, subresource in sub_walker_initial.get_sub_resources() ] diff --git a/src/pyrestresource/rest_types.py b/src/pyrestresource/rest_types.py index 98edc7c..6e0a7f8 100644 --- a/src/pyrestresource/rest_types.py +++ b/src/pyrestresource/rest_types.py @@ -1,14 +1,20 @@ # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring from __future__ import annotations -from enum import Enum, auto from typing import Union, get_origin, NewType, TypeVar, TYPE_CHECKING + +from enum import Enum, auto from datetime import datetime from pathlib import Path from uuid import UUID from ipaddress import IPv4Address, IPv4Network -if TYPE_CHECKING: - from .rest_resource import RestResourceBase +if TYPE_CHECKING is True: + pass + + +class RestResourceException(Exception): + pass + T_Gen_DictKeys: type = type({}.keys()) NoneType = type(None) diff --git a/test/test_rest_login.py b/test/test_rest_login.py index 96e451a..f3fb471 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -12,6 +12,7 @@ import socket import requests from contextlib import closing from multiprocessing import Process +from requests.adapters import HTTPAdapter print(__name__) print(__package__) @@ -43,20 +44,30 @@ chdir(testdir_path.parent.resolve()) # to allow mock-ing, all the tested classes are in a function def init_classes(): - user_CHACHA = UserLogin(username="chacha", secret="123456") + user_test = UserLogin(username="TestUser", secret="123456") + + class TestResource(RestResourceBase): + test_field: Optional[str] = Field("ORIGIN_VALUE") class TestResourceACL(RestResourceBase): test_field: Optional[str] = Field( "ORIGIN_VALUE", ACL=[ - ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user(name="chacha"), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test), rule=ACL_rule.ALLOW), ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ], ) @register_rest_rootpoint class RootApp(RestResourceBaseLogin): - _ar_user_login: ClassVar[list[UserLogin]] = [user_CHACHA] + _ar_user_login: ClassVar[list[UserLogin]] = [user_test] + test_resourceACL: TestResource = Field( + TestResource(), + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user(name=user_test.username), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) test_resource: TestResourceACL = TestResourceACL() # this add the classes to globals to allow using them later on @@ -73,21 +84,16 @@ def find_free_port(): def launch_server(ip, port): - print(f"port2={port}") init_classes() uvicorn.run(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True) -class Test_RestAPI_LOGIN(unittest.TestCase): +class Test_RestAPI_LOGIN_Web(unittest.TestCase): def setUp(self) -> None: chdir(testdir_path.parent.resolve()) - init_classes() - self.testapp = RootApp() - def test_access(self): + def test_login(self): ip, port = find_free_port() - print(f"ip1={ip}") - print(f"port1={port}") proc = Process( target=launch_server, args=( @@ -98,6 +104,139 @@ class Test_RestAPI_LOGIN(unittest.TestCase): proc.start() sleep(1) s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "TestUser"}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TestUser") + + finally: + proc.terminate() + s.close() + + def test_access_resourceACL(self): + ip, port = find_free_port() + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # before modification read + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resourceACL/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 500) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resourceACL", json={"test_field": "TEST SET VALUE"}) + self.assertEqual(response.status_code, 500) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resourceACL/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + # authenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resourceACL", json={"test_field": "TEST SET VALUE 2"}) + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") + + finally: + proc.terminate() + s.close() + + def test_access_fieldACL(self): + ip, port = find_free_port() + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + try: # before modification read response = s.get( @@ -106,8 +245,19 @@ class Test_RestAPI_LOGIN(unittest.TestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), "ORIGIN_VALUE") - # try unauthenticated write - response = s.put(f"http://{ip}:{port}/test_resource/test_field", json='"TEST SET VALUE"') + # try unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 500) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resource", json={"test_field": "TEST SET VALUE"}) self.assertEqual(response.status_code, 500) # check not modified @@ -120,11 +270,11 @@ class Test_RestAPI_LOGIN(unittest.TestCase): # login response = s.put( f"http://{ip}:{port}/login", - json={"username": "chacha", "secret": "123456"}, + json={"username": "TestUser", "secret": "123456"}, ) self.assertEqual(response.status_code, 201) - # authenticated write + # authenticated write (to field) response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") self.assertEqual(response.status_code, 201) @@ -135,73 +285,16 @@ class Test_RestAPI_LOGIN(unittest.TestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), "TEST SET VALUE") - finally: - proc.terminate() - s.close() - - def test_login(self): - result = self.testapp.process_request("/login", rsrc_verb.GET) - print("*****************") - print(result.get_result()) - - result = self.testapp.process_request("/login/username", rsrc_verb.GET) - print("*****************") - print(result.get_result()) - - # result = self.testapp.process_request("/login/secret", rsrc_verb.GET) - # print("*****************") - # print(result.get_result()) - - result = self.testapp.process_request("/login", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') - print("*****************") - print(result.get_result()) - - result = self.testapp.process_request("/login", rsrc_verb.GET) - print("*****************") - print(result.get_result()) - - result = self.testapp.process_request("/login/username", rsrc_verb.GET) - print("*****************") - print(result.get_result()) - - # result = self.testapp.process_request("/login/secret", rsrc_verb.GET) - # print("*****************") - # print(result.get_result()) - - -class Test_RestAPI_LOGIN_Web(unittest.TestCase): - def setUp(self) -> None: - chdir(testdir_path.parent.resolve()) - - def test_login(self): - ip, port = find_free_port() - print(f"ip1={ip}") - print(f"port1={port}") - proc = Process( - target=launch_server, - args=( - ip, - port, - ), - ) - proc.start() - sleep(1) - s = requests.Session() - try: - # Login in - - response = s.put( - f"http://{ip}:{port}/login", - json={"username": "chacha", "secret": "123456"}, - ) - print(response) - print("??????") - print(response.headers) + # authenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resource", json={"test_field": "TEST SET VALUE 2"}) self.assertEqual(response.status_code, 201) - response = s.get(f"http://{ip}:{port}/login") - - response = s.get(f"http://{ip}:{port}/") + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") finally: proc.terminate() diff --git a/test/test_rest_resource_plugins.py b/test/test_rest_resource_plugins.py index cffa355..e797b13 100644 --- a/test/test_rest_resource_plugins.py +++ b/test/test_rest_resource_plugins.py @@ -16,6 +16,7 @@ from src.pyrestresource import ( T_SupportedRESTFields, ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, + RestResourcePluginException_InvalidPluginSignature, ) testdir_path = Path(__file__).parent.resolve() @@ -34,6 +35,7 @@ def init_classes(): class ResourcePlugin_Info(ResourcePlugin_RestResourceBase_default): def handle_resource_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get: + print("HOOK GET !!") return Info_get(version="65.45", api_version="98.321") class Info_get(RestResourceBase): @@ -69,41 +71,9 @@ def init_classes(): def init_bad_plugin1(): - # plugin with missing handle_resource_put() method + # plugin not inheriting from the right base type class ResourcePlugin_TestResource: - def handle_field_get(self, resource: TestResource, params: RestRequestParams_GET) -> TestResource: - return resource - - class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] - - @register_rest_rootpoint - class RootApp2(RestResourceBase): - test: TestResource = Field(default=TestResource(tetvaluestr="testvalue")) - - RootApp2() - - -def init_bad_plugin2(): - # plugin with missing handle_resource_get() method - class ResourcePlugin_TestResource: - def handle_field_put(self, resource: TestResource, params: RestRequestParams_PUT) -> TestResource: - return resource - - class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] - - @register_rest_rootpoint - class RootApp2(RestResourceBase): - test: TestResource = Field(default=TestResource(tetvaluestr="testvalue")) - - RootApp2() - - -def init_bad_plugin3(): - # wrong plugin - class ResourcePlugin_TestResource(ResourcePlugin_RestResourceBase_default): - pass + ... class TestResource(RestResourceBase): tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] @@ -200,9 +170,5 @@ class Test_RestAPI_Plugin_GET(unittest.TestCase): self.assertEqual(result.get_result(), '"1.5.6"') def test_defect_plugin_field(self): - with self.assertRaises(RuntimeError): + with self.assertRaises(RestResourcePluginException_InvalidPluginSignature): init_bad_plugin1() - with self.assertRaises(RuntimeError): - init_bad_plugin2() - with self.assertRaises(RuntimeError): - init_bad_plugin3() diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py index 6eebc21..9329214 100644 --- a/test/test_rest_webserver.py +++ b/test/test_rest_webserver.py @@ -13,7 +13,7 @@ import socket import requests from contextlib import closing from multiprocessing import Process - +from requests.adapters import HTTPAdapter print(__name__) print(__package__) @@ -121,7 +121,6 @@ def find_free_port(): def launch_server(ip, port): - print(f"port2={port}") init_classes() uvicorn.run(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True) @@ -132,8 +131,6 @@ class Test_RestAPI_WebServer(unittest.TestCase): def test_nomal_AllCmd_games(self): ip, port = find_free_port() - print(f"ip1={ip}") - print(f"port1={port}") proc = Process( target=launch_server, args=( @@ -144,6 +141,8 @@ class Test_RestAPI_WebServer(unittest.TestCase): proc.start() sleep(1) s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + try: # Fetching games response = s.get(f"http://{ip}:{port}/games") @@ -291,8 +290,6 @@ class Test_RestAPI_WebServer(unittest.TestCase): n_loop = 10000 ip, port = find_free_port() - print(f"ip1={ip}") - print(f"port1={port}") proc = Process( target=launch_server, args=( @@ -303,6 +300,8 @@ class Test_RestAPI_WebServer(unittest.TestCase): proc.start() sleep(1) s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + try: start = time() for _ in range(n_loop): -- 2.47.3 From 4cc50808382cce01b3230ea4eb0ff59c48c29ef3 Mon Sep 17 00:00:00 2001 From: cclecle Date: Sun, 5 Nov 2023 22:11:21 +0000 Subject: [PATCH 10/20] add a feature to keep exception on when called from python / but not when called from uvicorn. --- src/pyrestresource/__init__.py | 19 +- src/pyrestresource/helpers.py | 8 + src/pyrestresource/rest_exceptions.py | 58 +++ src/pyrestresource/rest_login.py | 42 +- src/pyrestresource/rest_request.py | 71 +++- src/pyrestresource/rest_resource.py | 99 +++-- src/pyrestresource/rest_resource_handler.py | 69 ++-- src/pyrestresource/rest_resource_plugin.py | 30 +- src/pyrestresource/rest_resource_rootpoint.py | 20 +- src/pyrestresource/rest_resource_walker.py | 36 +- src/pyrestresource/rest_types.py | 7 +- test/test_ACL.py | 26 +- test/test_rest_login.py | 376 +++++++++++++++++- test/test_rest_resource.py | 7 +- test/test_rest_resource_plugins.py | 6 +- test/test_rest_webserver.py | 9 - 16 files changed, 678 insertions(+), 205 deletions(-) create mode 100644 src/pyrestresource/rest_exceptions.py diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index d129cc3..08402e3 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -32,7 +32,6 @@ if TYPE_CHECKING: T_T_DictKey, T_DictValues, T_T_DictValues, - RestResourceException, ) from .rest_request_opt import ( @@ -51,8 +50,20 @@ from .rest_resource_plugin import ( ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, ResourcePlugin_dict_default, - RestResourcePluginException, - RestResourcePluginException_InvalidPluginSignature, ) from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule -from .rest_login import RestResourceBaseLogin, UserLogin +from .rest_login import ( + RestResourceBaseLogin, + UserLogin, +) + +from .rest_exceptions import ( + RestResourceException, + RestResourceLoginException, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_ClientChange, + RestResourceLoginException_InvalidSession, + RestResourcePluginException, + RestResourcePluginException_InvalidPluginSignature, + RestResourceHandlerException_Forbiden, +) diff --git a/src/pyrestresource/helpers.py b/src/pyrestresource/helpers.py index e8c13d1..2d2723e 100644 --- a/src/pyrestresource/helpers.py +++ b/src/pyrestresource/helpers.py @@ -4,6 +4,7 @@ from __future__ import annotations from uuid import UUID import json +import traceback from .rest_types import T_Gen_DictKeys @@ -30,3 +31,10 @@ def parse_dict_cookies(cookies: str) -> dict[str, str]: name, value = item.split("=", 1) result[name] = value return result + + +def forward_exception(e: Exception, forward: bool) -> None: + if forward: + raise e from None + else: + traceback.print_exc() diff --git a/src/pyrestresource/rest_exceptions.py b/src/pyrestresource/rest_exceptions.py new file mode 100644 index 0000000..c779dac --- /dev/null +++ b/src/pyrestresource/rest_exceptions.py @@ -0,0 +1,58 @@ +class RestResourceException(Exception): + pass + + +class RestResourceModelException(RestResourceException): + pass + + +class RestResourceModelException_ACL(RestResourceModelException): + pass + + +class RestResourceHandlerException(RestResourceException): + pass + + +class RestResourceHandlerException_ResourceNotFound(RestResourceHandlerException): + pass + + +class RestResourceHandlerException_MethodNotAllowed(RestResourceHandlerException): + pass + + +class RestResourceHandlerException_BadRequest(RestResourceHandlerException): + pass + + +class RestResourceHandlerException_Forbiden(RestResourceHandlerException): + pass + + +class RestResourceLoginException(RestResourceException): + pass + + +class RestResourceLoginException_SessionTimeout(RestResourceLoginException): + pass + + +class RestResourceLoginException_ClientChange(RestResourceLoginException): + pass + + +class RestResourceLoginException_InvalidSession(RestResourceLoginException): + pass + + +class RestResourceLoginException_InvalidCredentials(RestResourceLoginException): + pass + + +class RestResourcePluginException(RestResourceException): + pass + + +class RestResourcePluginException_InvalidPluginSignature(RestResourcePluginException): + pass diff --git a/src/pyrestresource/rest_login.py b/src/pyrestresource/rest_login.py index 707e421..f91de8c 100644 --- a/src/pyrestresource/rest_login.py +++ b/src/pyrestresource/rest_login.py @@ -15,13 +15,19 @@ from __future__ import annotations from typing import Optional, ClassVar, TYPE_CHECKING from secrets import token_hex, compare_digest -from datetime import datetime +from datetime import datetime, timedelta from pydantic import BaseModel, Field from .rest_types import rsrc_verb from .rest_resource import RestResourceBase from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule, ACL_target_user from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default +from .rest_exceptions import ( + RestResourceLoginException_InvalidCredentials, + RestResourceLoginException_ClientChange, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_InvalidSession, +) if TYPE_CHECKING is True: from .rest_request import RestRequest, RestRequestParams_GET @@ -35,7 +41,7 @@ class UserLogin(BaseModel): class UserSession(BaseModel): last_update: datetime user_login: UserLogin - host: Optional[str] + client: Optional[tuple[str, int]] class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): @@ -65,30 +71,38 @@ class Login(RestResourceBase): class RestResourceBaseLogin(RestResourceBase): _ar_user_login: ClassVar[list[UserLogin]] = [] _ar_user_session: dict[str, UserSession] = {} - _max_session_time_minutes: ClassVar[int] = 20 + _max_session_inactive: ClassVar[timedelta] = timedelta(minutes=20) + _max_session_time: ClassVar[timedelta] = timedelta(hours=12) login: Login = Field(default=Login(), plugin=ResourcePlugin_Login) + def get_new_cookie_expiration_date(self) -> datetime: + return datetime.now() + self._max_session_time + def _process_request_session(self, request: RestRequest) -> None: + # print(f"[TRACE] {type(self).__name__}->_process_request_session()") + # print(f"[TRACE] request: {id(request)}") auth_cookie = request.get_cookie("Authorization") if auth_cookie != None: if auth_cookie in self._ar_user_session: - print("USER SESSION FOUND !") - print(self._ar_user_session[auth_cookie].user_login.username) - print(auth_cookie) + # print(f"SESSION FOUND for {request.get_client()}") - time_diff_min = (datetime.now() - self._ar_user_session[auth_cookie].last_update).total_seconds() / 60 - - if time_diff_min > self._max_session_time_minutes: + if self._ar_user_session[auth_cookie].client != request.get_client(): del self._ar_user_session[auth_cookie] - raise RuntimeError("session timeout ! (session reseted)") + raise RestResourceLoginException_ClientChange() + + time_diff = datetime.now() - self._ar_user_session[auth_cookie].last_update + if time_diff > self._max_session_inactive: + del self._ar_user_session[auth_cookie] + raise RestResourceLoginException_SessionTimeout() request.set_user(ACL_target_user(name=self._ar_user_session[auth_cookie].user_login.username)) + # print("SESSION RECOVERED") return - print("Invalid session") + raise RestResourceLoginException_InvalidSession() return - print("non-connected user") + # print(f"non-connected user {request.get_client()}") def user_login(self, user_name: str, user_secret: str, request: RestRequest) -> str: already_failed: bool = False @@ -107,10 +121,10 @@ class RestResourceBaseLogin(RestResourceBase): pass if already_failed: - raise RuntimeError("Wrong auth") # TODO: specific exception + raise RestResourceLoginException_InvalidCredentials() def _register_user_session(self, user_login: UserLogin, request: RestRequest) -> str: token = token_hex(16) - new_user_session = UserSession(last_update=datetime.now(), user_login=user_login, host=request.get_host()) + new_user_session = UserSession(last_update=datetime.now(), user_login=user_login, client=request.get_client()) self._ar_user_session[f"Bearer {token}"] = new_user_session return token diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index 0d37b7d..e7dd54c 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -31,6 +31,7 @@ from .helpers import parse_dict_cookies if TYPE_CHECKING is True: from typing import Optional from .rest_types import T_SupportedRESTFields + from .rest_resource import RestResourceBase class RequestFactory( @@ -49,7 +50,9 @@ class RequestFactory( cls_RestRequestParams_POST: type[RestRequestParams_POST] = Field(default=RestRequestParams_POST) cls_RestRequestParams_DELETE: type[RestRequestParams_DELETE] = Field(default=RestRequestParams_DELETE) - def get_RestRequest(self, url: str, verb: rsrc_verb, data: dict, query_string: Optional[str] = None) -> RestRequest: + def get_RestRequest( + self, root_resource: RestResourceBase, url: str, verb: rsrc_verb, data: dict, query_string: Optional[str] = None + ) -> RestRequest: """get a RestRequets instance based on LUT_verb configuration Args: @@ -60,14 +63,14 @@ class RequestFactory( # /!\ mypy seems not being able to propagate typevar to composed classes if verb is rsrc_verb.GET: - return RestRequest[RestRequestParams_GET](self.cls_RestRequestParams_GET, url, verb, data, query_string) + return RestRequest[RestRequestParams_GET](self.cls_RestRequestParams_GET, root_resource, url, verb, data, query_string) if verb is rsrc_verb.PUT: - return RestRequest[RestRequestParams_PUT](self.cls_RestRequestParams_PUT, url, verb, data, query_string) + return RestRequest[RestRequestParams_PUT](self.cls_RestRequestParams_PUT, root_resource, url, verb, data, query_string) if verb is rsrc_verb.POST: - return RestRequest[RestRequestParams_POST](self.cls_RestRequestParams_POST, url, verb, data, query_string) + return RestRequest[RestRequestParams_POST](self.cls_RestRequestParams_POST, root_resource, url, verb, data, query_string) if verb is rsrc_verb.DELETE: - return RestRequest[RestRequestParams_DELETE](self.cls_RestRequestParams_DELETE, url, verb, data, query_string) - raise RuntimeError("Invalid Verb") + return RestRequest[RestRequestParams_DELETE](self.cls_RestRequestParams_DELETE, root_resource, url, verb, data, query_string) + raise RestResourceHandlerException_MethodNotAllowed("Invalid Verb") def update_RestRequest(self, request: RestRequest) -> None: """create an updated copy of a RestRequest object based on a different LUT_verb configuration @@ -85,7 +88,7 @@ class RequestFactory( elif request.verb is rsrc_verb.DELETE: request.update_ReqParams(self.cls_RestRequestParams_DELETE) else: - raise RuntimeError("Invalid Verb") + raise RestResourceHandlerException_MethodNotAllowed("Invalid Verb") return @@ -96,12 +99,11 @@ class RestRequest(Generic[_T_RestRequestParams]): def __init__( self, type_request_params: type[_T_RestRequestParams], + root_resource: RestResourceBase, url: str, verb: rsrc_verb, data: Optional[dict[str, T_SupportedRESTFields]] = None, query_string: Optional[str] = None, - incoming_cookie: dict[str, str] = {}, - outgoing_cookie: dict[str, str] = {}, ) -> None: """class to handle a request context, that will be kept and updated while walking url parts @@ -118,27 +120,29 @@ class RestRequest(Generic[_T_RestRequestParams]): self.url: str self.verb: rsrc_verb self.data: dict - self.raw_headers: list[Any] + self._raw_headers: list[Any] = [] + self._client: tuple[str, int] = () self.headers: dict[str, None | str | dict[str, None | str]] = {"host": None, "cookie": {}} self._saved_url_params: dict self.ReqParams: _T_RestRequestParams = type_request_params() self.url_stack: list[str] self._saved_url_stack: list[str] self.url_stack_index: int - self.incoming_cookie: dict[str, str] = incoming_cookie - self.outgoing_cookie: dict[str, str] = outgoing_cookie + self.outgoing_cookie: dict[str, str] = {} self.user: ACL_target_user = ACL_target_user_Annonymous() self.groups: list[ACL_target_group] = [] self.result: Optional[str] = None + self._forced_status: Optional[int] = None + self.root_resource: RestResourceBase = root_resource # = or create a fresh one = if url is None or verb is None or data is None: - raise RuntimeError("url and verb and data must be set") + raise RestResourceException("url and verb and data must be set") self.url = url self.verb = verb if data != {} and not check_type(data, T_AllSupportedFields): - raise RuntimeError(f"Wrong data type received: {data}") + raise RestResourceHandlerException_BadRequest(f"Wrong data type received: {data}") self.data = data @@ -157,13 +161,34 @@ class RestRequest(Generic[_T_RestRequestParams]): self._saved_url_stack = self.url_stack.copy() self.url_stack_index = 0 + def set_resp_status(self, status: int) -> None: + self._forced_status = status + + def get_root_resource(self) -> RestResourceBase: + return self.root_resource + + def get_status(self) -> int: + if self._forced_status is not None: + return self._forced_status + + if self.verb in (rsrc_verb.POST, rsrc_verb.PUT): + return 201 + + return 200 + + def set_client(self, client: tuple[str, int]) -> None: + self._client = client + + def get_client(self) -> tuple[str, int]: + return self._client + def set_headers(self, headers: list[Any]) -> None: - self.raw_headers = headers - for elem in self.raw_headers: + self._raw_headers = headers + for elem in self._raw_headers: if elem[0] == b"host": self.headers["host"] = elem[1].decode("utf-8") - # elif elem[0] == b"user-agent": - # self.headers["user-agent"] = elem[1].decode("utf-8") + elif elem[0] == b"user-agent": + self.headers["user-agent"] = elem[1].decode("utf-8") elif elem[0] == b"cookie": self.headers["cookie"] = parse_dict_cookies(elem[1].decode("utf-8")) @@ -172,8 +197,16 @@ class RestRequest(Generic[_T_RestRequestParams]): return None return self.headers["cookie"][key] + def set_resp_cookie_value(self, key: str, value: str) -> None: + self.outgoing_cookie[ + key + ] = f"{value}; expires={self.root_resource.get_new_cookie_expiration_date().strftime('%a, %d %b %Y %H:%M:%S GMT')}; path=/; HttpOnly" + + def reset_resp_cookie(self, key: str) -> None: + self.outgoing_cookie[key] = "null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT" + def get_host(self) -> str: - print(self.headers["host"]) + return self.headers["host"] def set_result(self, result: str): self.result = result diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index bd42c32..b7c00c0 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -8,9 +8,11 @@ from typing import ( from abc import ABC import json +import pprint + from pydantic import BaseModel -from .helpers import _JSONEncoder +from .helpers import _JSONEncoder, forward_exception from .rest_types import rsrc_verb, _T_SupportedRESTFields from .rest_ACL import ( @@ -22,6 +24,17 @@ from .rest_ACL import ( ) from .rest_request import RestRequest +from .rest_exceptions import ( + RestResourceLoginException_InvalidSession, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_ClientChange, + RestResourceLoginException_InvalidCredentials, + RestResourceHandlerException_ResourceNotFound, + RestResourceHandlerException_MethodNotAllowed, + RestResourceHandlerException_BadRequest, + RestResourceHandlerException_Forbiden, + RestResourceException, +) if TYPE_CHECKING is True: from .rest_types import ( @@ -64,15 +77,15 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if acl.rule is ACL_rule.ALLOW: # print("ALLOWED (user)") return - raise RuntimeError(f"Not allowed access detected: {field}") + raise RestResourceHandlerException_Forbiden(f"Not allowed access detected: {field}") elif isinstance(acl.target, ACL_target_group): if isinstance(acl.target, ACL_target_group_Any) or any(_ for _ in groups if _.name == acl.target.name): if acl.rule is ACL_rule.ALLOW: # print("ALLOWED (group)") return - raise RuntimeError(f"Not allowed access detected: {field}") + raise RestResourceHandlerException_Forbiden(f"Not allowed access detected: {field}") else: - raise RuntimeError(f"Wrong ACL target type: {field}") + raise RestResourceException(f"Wrong ACL target type: {field}") # print("ALLOWED (Default)") def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None: @@ -89,7 +102,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if key in self.model_fields: self._check_acl(request.user, request.groups, rsrc_verb.PUT, key) else: - raise RuntimeError("Incompatible verb") + raise RestResourceException("Incompatible verb") def update(self, **new_data): for field, value in new_data.items(): @@ -119,27 +132,25 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if b"content-type" in scope["headers"]: assert scope["headers"][b"content-type"] == b"application/json" - # import pprint - - # print("----REC HEADER ---") - # pprint.pprint(scope["headers"]) + # pprint.pprint(scope) body = await self.read_body(receive) - verb = rsrc_verb[scope["method"]] request: RestRequest = self.process_request( - scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8"), scope["headers"] + scope["path"], + rsrc_verb[scope["method"]], + body.decode("utf-8"), + scope["query_string"].decode("utf-8"), + scope["client"], + scope["headers"], + True, ) - assert request != None - - status = 200 - if verb in (rsrc_verb.POST, rsrc_verb.PUT): - status = 201 + assert request is not None header_resp = { "type": "http.response.start", - "status": status, + "status": request.get_status(), "headers": [ [b"content-type", b"application/json"], ], @@ -148,8 +159,6 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): for name, value in request.outgoing_cookie.items(): header_resp["headers"].append(["Set-Cookie", f"{name}={value}"]) - # print("----SENT HEADER ---") - # pprint.pprint(header_resp) await send(header_resp) body = None @@ -172,7 +181,9 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): verb: rsrc_verb = rsrc_verb.GET, data_json: Optional[str] = None, query_string: Optional[str] = None, + client: Optional[tuple[str, int]] = None, headers: Optional[list[Any]] = None, + http_mode: bool = False, ) -> RestRequest: from .rest_resource_handler import ( ResourceHandler, @@ -188,22 +199,50 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): # preparing request & session request: RestRequest = ressource_handler.get_request() + assert request is not None + if headers is not None: request.set_headers(headers) + + if client is not None: + request.set_client(client) + + try: self._process_request_session(request) - # processing the verb - result = ressource_handler.process_verb() + result = ressource_handler.process_verb() - # print("OOO") - # print(type(self)._resp_cookies) - # print("OOO2") + if isinstance(result, RestResourceBase): + request.set_result(json.dumps(result.model_dump(mode="json"))) + elif result is not None: + request.set_result(json.dumps(result, cls=_JSONEncoder)) + else: + request.set_result("null") - if isinstance(result, RestResourceBase): - request.set_result(json.dumps(result.model_dump(mode="json"))) - elif result is not None: - request.set_result(json.dumps(result, cls=_JSONEncoder)) - else: - request.set_result("null") + except RestResourceHandlerException_ResourceNotFound as e: + request.set_resp_status(404) + forward_exception(e, not http_mode) + + except RestResourceHandlerException_MethodNotAllowed as e: + request.set_resp_status(405) + forward_exception(e, not http_mode) + + except RestResourceHandlerException_BadRequest as e: + request.set_resp_status(400) + forward_exception(e, not http_mode) + + except RestResourceHandlerException_Forbiden as e: + request.set_resp_status(403) + forward_exception(e, not http_mode) + + except ( + RestResourceLoginException_InvalidSession, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_ClientChange, + RestResourceLoginException_InvalidCredentials, + ) as e: + request.set_resp_status(401) + request.reset_resp_cookie("Authorization") + forward_exception(e, not http_mode) return request diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 77cfd42..3197c53 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -33,6 +33,14 @@ from .rest_request_opt import ( _T_RestRequestParams_PUT, ) +from .rest_exceptions import ( + RestResourceHandlerException, + RestResourceHandlerException_ResourceNotFound, + RestResourceHandlerException_MethodNotAllowed, + RestResourceHandlerException_BadRequest, + RestResourceHandlerException_Forbiden, +) + if TYPE_CHECKING is True: from .rest_types import T_T_DictKey, T_T_DictValues from .rest_request import RestRequest @@ -83,7 +91,6 @@ class ResourceHandler( self.next_handler: Optional[ResourceHandler] = None self.saved_url: list[str] = [] self.resource: _T_Resource = resource - self.root_resource: _T_Resource = resource if prev_handler is None else prev_handler.root_resource self.req: RestRequest if prev_handler is not None: self.prev_handler = prev_handler @@ -91,13 +98,13 @@ class ResourceHandler( self._request_factory.update_RestRequest(self.req) elif None in [url, verb]: - raise RuntimeError("if req not set, url,verb must be setted") + raise RestResourceHandlerException("if req not set, url,verb must be setted") else: if url is None or verb is None: - raise RuntimeError("url and verb must be set") + raise RestResourceHandlerException("url and verb must be set") if data is None: data = {} - self.req = self._request_factory.get_RestRequest(url, verb, data, query_string) + self.req = self._request_factory.get_RestRequest(resource, url, verb, data, query_string) # print(f"[TRACE] creating {type(self).__name__}() with url={self.req.get_url_stack()}") @@ -116,7 +123,7 @@ class ResourceHandler( if resource_handler_cls._check_resource_handler(resource, req): # print(f"[DEBUG] match ResourceHandler: {resource_handler_cls.__name__}") return resource_handler_cls - raise RuntimeError(f"Unsupported Resource Type {type(resource).__name__}") + raise RestResourceHandlerException(f"Unsupported Resource Type {type(resource).__name__}") @classmethod def register_resource_handler(cls, other_cls) -> None: @@ -187,7 +194,7 @@ class ResourceHandler( return next_resource_handler # in _find_resource context, only resource's real values can be retrieved - raise RuntimeError("Wrong request") + raise RestResourceHandlerException_ResourceNotFound() def _check_access_rights(self): pass @@ -210,7 +217,7 @@ class ResourceHandler( self._process_delete() return None - raise RuntimeError("Invalid Verb") + raise RestResourceHandlerException_BadRequest("Invalid Verb") def _process_get( self, @@ -231,16 +238,16 @@ class ResourceHandler( self._handle_process_delete(self.req.get_req_params()) def _handle_process_get(self, params: _T_RestRequestParams_GET) -> _T_Resource | list[T_DictKey]: - raise RuntimeError(f"GET method not implemented for {type(self).__name__}") + raise RestResourceHandlerException_MethodNotAllowed(f"GET method not implemented for {type(self).__name__}") def _handle_process_put(self, params: _T_RestRequestParams_PUT) -> None: - raise RuntimeError(f"PUT method not implemented for {type(self).__name__}") + raise RestResourceHandlerException_MethodNotAllowed(f"PUT method not implemented for {type(self).__name__}") def _handle_process_post(self, params: _T_RestRequestParams_POST) -> Optional[T_DictKey]: - raise RuntimeError(f"POST method not implemented for {type(self).__name__}") + raise RestResourceHandlerException_MethodNotAllowed(f"POST method not implemented for {type(self).__name__}") def _handle_process_delete(self, params: _T_RestRequestParams_DELETE) -> None: - raise RuntimeError(f"DELETE method not implemented for {type(self).__name__}") + raise RestResourceHandlerException_MethodNotAllowed(f"DELETE method not implemented for {type(self).__name__}") @ResourceHandler.register_resource_handler @@ -289,8 +296,7 @@ class ResourceHandler_dict( # print(f"{type(self).__name__}->_handle_process_delete()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") - if self.prev_handler is None: - raise RuntimeError("Wrong command") + assert self.prev_handler is not None dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] @@ -308,8 +314,7 @@ class ResourceHandler_dict( # print(f"{type(self).__name__}->_handle_process_post()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") - if self.prev_handler is None: - raise RuntimeError("Wrong command") + assert self.prev_handler is not None dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] dict_value_type: T_T_DictValues = cast(RestResourceBase, self.prev_handler.resource)._dict_value_type_[ @@ -341,7 +346,9 @@ class ResourceHandler_dict( _dict[_obj_primary_key] = _obj return _obj_primary_key - RuntimeError("Either the object needs defined primary key or the request must contain an API_key param to process this command") + raise RestResourceHandlerException_BadRequest( + "Either the object needs defined primary key or the request must contain an API_key param to process this command" + ) return None # for mypy.... @@ -381,8 +388,7 @@ class ResourceHandler_dict_elem( # print(f"{type(self).__name__}->_process_get()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") - if self.prev_handler is None: - raise RuntimeError("Wrong command") + assert self.prev_handler is not None dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] @@ -401,8 +407,7 @@ class ResourceHandler_dict_elem( # instead of expected get_resource_origin(1) because we need to go backward # because self.req is another context that is not saved to improve performances - if self.prev_handler is None: - raise RuntimeError("Wrong command") + assert self.prev_handler is not None dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(2)] @@ -460,13 +465,13 @@ class ResourceHandler_RestResourceBase( # print(self.resource.exclude) if self.req.get_resource_origin(0) not in self.resource.model_fields: - raise RuntimeError(f"Unknown field access detected: {self.req.get_url_stack()}") + raise RestResourceHandlerException_ResourceNotFound(f"Unknown field access detected: {self.req.get_url_stack()}") self.resource.check_acl_field(self.req) if len(self.req.get_url_stack()) == 0: # destination reached if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET: - raise RuntimeError(f"Not allowed READ access detected: {self.req.get_url_stack()}") + raise RestResourceHandlerException_ResourceNotFound(f"Not allowed READ access detected: {self.req.get_url_stack()}") def _handle_process_get(self, params) -> RestResourceBase: # print(f"{type(self).__name__}->_process_get()") @@ -480,13 +485,13 @@ class ResourceHandler_RestResourceBase( if key in self.resource._plugins_: if issubclass(self.resource._plugins_[key], ResourcePlugin_field): plugin_field: ResourcePlugin_field = cast( - ResourcePlugin_field, self.resource._plugins_[key](self.req, self.root_resource) + ResourcePlugin_field, self.resource._plugins_[key](self.req, self.req.get_root_resource()) ) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_field_get(value, params)) elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_field: ResourcePlugin_field = cast( - ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource) + ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.req.get_root_resource()) ) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_resource_get(value, params)) @@ -509,14 +514,14 @@ class ResourceHandler_RestResourceBase( if issubclass(self.resource._plugins_[key], ResourcePlugin_field): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.resource._plugins_[key](self.req, self.root_resource), + self.resource._plugins_[key](self.req, self.req.get_root_resource()), ) value = plugin_rsrc.handle_field_get(value, params) elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.resource._plugins_[key](self.req, self.root_resource), + self.resource._plugins_[key](self.req, self.req.get_root_resource()), ) value = plugin_rsrc.handle_resource_get(value, params) @@ -539,7 +544,7 @@ class ResourceHandler_RestResourceBase( if key in _new_resrc._plugins_: if issubclass(_new_resrc._plugins_[key], ResourcePlugin_field): plugin_field: ResourcePlugin_field = cast( - ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.root_resource) + ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.req.get_root_resource()) ) value = getattr(_new_resrc, key) setattr(_new_resrc, key, plugin_field.handle_field_put(value, params)) @@ -555,7 +560,7 @@ class ResourceHandler_RestResourceBase( if key in self.prev_handler.prev_handler.resource._plugins_: plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.prev_handler.prev_handler.resource._plugins_[key](self.req, self.root_resource), + self.prev_handler.prev_handler.resource._plugins_[key](self.req, self.req.get_root_resource()), ) _new_resrc = plugin_rsrc.handle_dict_elem_put(_new_resrc, params) # element is within a RestResourceBase @@ -564,7 +569,7 @@ class ResourceHandler_RestResourceBase( if key in self.prev_handler.resource._plugins_: plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.prev_handler.resource._plugins_[key](self.req, self.root_resource), + self.prev_handler.resource._plugins_[key](self.req, self.req.get_root_resource()), ) _new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params) @@ -584,7 +589,7 @@ class ResourceHandler_RestResourceBase( ): self.prev_handler._process_delete() else: - raise RuntimeError("cannot delete an element outside a dict") + raise RestResourceHandlerException_BadRequest("cannot delete an element outside a dict") @ResourceHandler.register_resource_handler @@ -615,7 +620,7 @@ class ResourceHandler_simple( if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, - self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource), + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.req.get_root_resource()), ) return plugin_simple.handle_field_get(self.resource, params) @@ -636,7 +641,7 @@ class ResourceHandler_simple( # print("PLUGIN FOUND") plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, - self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource), + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.req.get_root_resource()), ) # print(value) value = plugin_simple.handle_field_put(value, params) diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py index 4e7bcea..6a97876 100644 --- a/src/pyrestresource/rest_resource_plugin.py +++ b/src/pyrestresource/rest_resource_plugin.py @@ -2,16 +2,17 @@ from __future__ import annotations from typing import Optional, Generic, TYPE_CHECKING from abc import abstractmethod, ABC +from datetime import datetime from .rest_types import ( _T_DictValues, _T_DictKey, TV_SupportedRESTFields, TV_RestResourceBase, - RestResourceException, ) from .rest_request import RestRequest + if TYPE_CHECKING is True: from .rest_resource import RestResourceBase from .rest_request_opt import ( @@ -27,14 +28,6 @@ if TYPE_CHECKING is True: ) -class RestResourcePluginException(RestResourceException): - pass - - -class RestResourcePluginException_InvalidPluginSignature(RestResourcePluginException): - pass - - class ResourcePlugin(ABC): def __init__(self, request: RestRequest, root_resource: RestResourceBase) -> None: self.__request: RestRequest = request @@ -46,16 +39,17 @@ class ResourcePlugin(ABC): def get_user_login(self) -> str: return self.__request.get_user().name - def getr_req_cookie_value(self, key: str) -> Optional[str]: - return self.__request.incoming_cookie[key] + def set_resp_cookie_value(self, key: str, value: str) -> None: + self.__request.set_resp_cookie_value(key, value) - def set_resp_cookie_value(self, key: str, value: str): - # print("AAA") - # print(name) - # print(value) - # print(self.cookies) - # print(type(self.cookies)) - self.__request.outgoing_cookie[key] = value + def reset_resp_cookie(self, key: str) -> None: + self.__request.reset_resp_cookie(key) + + def get_new_cookie_expiration_date(self) -> datetime: + return self.__root_resource.get_new_cookie_expiration_date() + + def set_resp_status(self, status: int) -> None: + self.__request.set_resp_status(status) class ResourcePlugin_field(ResourcePlugin, Generic[TV_SupportedRESTFields]): diff --git a/src/pyrestresource/rest_resource_rootpoint.py b/src/pyrestresource/rest_resource_rootpoint.py index a13e521..ac62b58 100644 --- a/src/pyrestresource/rest_resource_rootpoint.py +++ b/src/pyrestresource/rest_resource_rootpoint.py @@ -12,7 +12,6 @@ from .rest_resource_plugin import ( ResourcePlugin_field, ResourcePlugin_RestResourceBase, ResourcePlugin_dict, - RestResourcePluginException_InvalidPluginSignature, ) from .rest_resource_walker import ( RestResourceWalker_Root, @@ -26,6 +25,7 @@ from .rest_ACL import ( ACL_target_group_Any, ACL_rule, ) +from .rest_exceptions import RestResourcePluginException_InvalidPluginSignature, RestResourceModelException, RestResourceModelException_ACL if TYPE_CHECKING is True: pass @@ -37,9 +37,9 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): # checking compatibility if not get_origin(datatype[1]) is None: - raise RuntimeError("complex dict types are not supported (should create a RestResourceBase container)") + raise RestResourceModelException("complex dict types are not supported (should create a RestResourceBase container)") if not datatype[0] in _T_SupportedRESTFields: - raise RuntimeError(f"Unsupported Dict Field value type in class (key)") + raise RestResourceModelException(f"Unsupported Dict Field value type in class (key)") # preprocessing types / structure if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): @@ -69,10 +69,10 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): # print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] else: - raise RuntimeError("ACL must be a list()") + raise RestResourceModelException_ACL("ACL must be a list()") else: - raise RuntimeError("dict must be contained in a RestResourceBase") + raise RestResourceModelException("dict must be contained in a RestResourceBase") class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFields): @@ -96,7 +96,9 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi if "primary_key" in self.resource.json_schema_extra and self.resource.json_schema_extra["primary_key"] is True: if self.parent.annotation._primary_key_ is not None: - raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") + raise RestResourceModelException( + f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}" + ) self.parent.annotation._primary_key_ = self.resource_name self.parent.annotation._ACL_record_[self.resource_name] = [ ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY) @@ -114,10 +116,10 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi # print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] else: - raise RuntimeError("ACL must be a list()") + raise RestResourceModelException_ACL("ACL must be a list()") else: - raise RuntimeError("fields must be contained in a RestResourceBase") + raise RestResourceModelException("fields must be contained in a RestResourceBase") class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_RestResourceBase): @@ -153,7 +155,7 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ # print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] else: - raise RuntimeError("ACL must be a list()") + raise RestResourceModelException_ACL("ACL must be a list()") class RestResourceWalker_Root__tree_init(RestResourceWalker_Root): diff --git a/src/pyrestresource/rest_resource_walker.py b/src/pyrestresource/rest_resource_walker.py index a72a834..f3d14fe 100644 --- a/src/pyrestresource/rest_resource_walker.py +++ b/src/pyrestresource/rest_resource_walker.py @@ -15,6 +15,7 @@ from pydantic.fields import FieldInfo from .rest_types import _T_SupportedRESTFields from .rest_resource import RestResourceBase +from .rest_exceptions import RestResourceModelException if TYPE_CHECKING is True: from typing import Any, Optional @@ -59,7 +60,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): if _is_valid is True: return sub(resource_name, resource, parent, _anno, _optional, argument) - raise RuntimeError(f"Incompatible Field Found: {type(resource).__name__}") + raise RestResourceModelException(f"Incompatible Field Found: {type(resource).__name__}") return None def __init__( @@ -91,35 +92,10 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): self.optional = _optional if self.annotation is None: - raise RuntimeError("Only annotated types are allowed in RestResourceBase derived classes") + raise RestResourceModelException("Only annotated types are allowed in RestResourceBase derived classes") self.subdatatype = get_args(self.annotation) - """ - def info(self) -> None: - print(f"{type(self).__name__}->info()") - print("==========================") - print(f"resource_name: {self.resource_name}") - print(f"resource: {type(self.resource).__name__}") - print(f"resource: {self.resource}") - print(f"parent: {self.parent}") - print(f"annotation: {self.annotation}") - print(f"optional: {self.optional}") - print(f"subdatatype: {self.subdatatype}") - - # -> cannot do that on dicts - # if self.parent is not None: - # print(f"_model_dump_excluded_: {self.parent.annotation._model_dump_excluded_}") - - if False: - print("------ STACK ------") - _rsrc = self.parent - while _rsrc is not None: - print(f"{id(_rsrc.annotation)}:{_rsrc.annotation}") - _rsrc = _rsrc.parent - print("-------------------") - """ - @abstractmethod def get_future(self) -> Optional[RestResourceWalkerFutureResult]: return self.future_result @@ -163,7 +139,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): elif not isinstance(resource, FieldInfo) and issubclass(resource, RestResourceBase): _anno = resource else: - raise RuntimeError("Incompatible resource type") + raise RestResourceModelException("Incompatible resource type") _datatype = get_args(_anno) _optional: bool = False @@ -176,7 +152,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): _anno = _datatype[0] _optional = True else: - raise RuntimeError("Union is only allowed to describe Optional (e.g. Union[XXX,None])") + raise RestResourceModelException("Union is only allowed to describe Optional (e.g. Union[XXX,None])") return _anno, _optional @@ -277,5 +253,5 @@ class RestResourceWalker_Root: current_deep = current_deep + 1 return sub_walker_initial.chain_process_future() else: - raise RuntimeError("Invalid Rootpoint") + raise RestResourceModelException("Invalid Rootpoint") return None diff --git a/src/pyrestresource/rest_types.py b/src/pyrestresource/rest_types.py index 6e0a7f8..da34850 100644 --- a/src/pyrestresource/rest_types.py +++ b/src/pyrestresource/rest_types.py @@ -12,10 +12,6 @@ if TYPE_CHECKING is True: pass -class RestResourceException(Exception): - pass - - T_Gen_DictKeys: type = type({}.keys()) NoneType = type(None) @@ -63,8 +59,7 @@ TV_SupportedRESTFields = TypeVar( NoneType, ) -if get_origin(T_SupportedRESTFields) is not Union: - raise RuntimeError("wrong T_SupportedRESTFields (must be flat Union)") +assert get_origin(T_SupportedRESTFields) is Union TV_RestResourceBase = TypeVar("TV_RestResourceBase", bound="RestResourceBase") diff --git a/test/test_ACL.py b/test/test_ACL.py index 5e69df8..7847676 100644 --- a/test/test_ACL.py +++ b/test/test_ACL.py @@ -5,12 +5,8 @@ from pathlib import Path from typing import Optional from pydantic import Field - -print(__name__) -print(__package__) - - from src.pyrestresource import ( + RestResourceHandlerException_Forbiden, register_rest_rootpoint, RestResourceBase, rsrc_verb, @@ -85,11 +81,11 @@ class Test_RestAPI_ACL(unittest.TestCase): result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_ro/version_ro", rsrc_verb.PUT, '"6.6.6"') self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3") - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_ro", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}') self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3") @@ -107,7 +103,7 @@ class Test_RestAPI_ACL(unittest.TestCase): self.assertEqual(result.get_result(), "null") self.assertEqual(self.testapp.resource_with_secret.username, None) - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) self.assertEqual(self.testapp.resource_with_secret.secret, None) @@ -122,7 +118,7 @@ class Test_RestAPI_ACL(unittest.TestCase): self.assertEqual(result.get_result(), '"chacha"') self.assertEqual(self.testapp.resource_with_secret.username, "chacha") - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) self.assertEqual(self.testapp.resource_with_secret.secret, "123456") @@ -138,13 +134,13 @@ class Test_RestAPI_ACL(unittest.TestCase): self.assertEqual(result.get_result(), '"chacha"') self.assertEqual(self.testapp.resource_with_secret.username, "chacha") - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) result = self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.PUT, '"123456"') self.assertEqual(result.get_result(), "null") - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) self.assertEqual(self.testapp.resource_with_secret.secret, "123456") @@ -160,23 +156,23 @@ class Test_RestAPI_ACL(unittest.TestCase): self.assertEqual(result.get_result(), "null") self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret_ACL/secret", rsrc_verb.GET) self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret_ACL", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) def test_subresource_ACL_field(self): - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret_ACL/username", rsrc_verb.PUT, '"chacha"') self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret_ACL/secret", rsrc_verb.PUT, '"123456"') self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) diff --git a/test/test_rest_login.py b/test/test_rest_login.py index f3fb471..9107fbc 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -1,12 +1,10 @@ from __future__ import annotations import unittest -from unittest.mock import patch from os import chdir from pathlib import Path -from typing import Optional, Annotated, ClassVar +from typing import Optional, ClassVar from pydantic import Field -from uuid import UUID, uuid4 -from time import time, sleep +from time import sleep import uvicorn import socket import requests @@ -14,10 +12,6 @@ from contextlib import closing from multiprocessing import Process from requests.adapters import HTTPAdapter -print(__name__) -print(__package__) - - from src.pyrestresource import ( ACL_target_user, UserLogin, @@ -45,6 +39,7 @@ chdir(testdir_path.parent.resolve()) # to allow mock-ing, all the tested classes are in a function def init_classes(): user_test = UserLogin(username="TestUser", secret="123456") + user_test2 = UserLogin(username="TestUser2", secret="abcdef") class TestResource(RestResourceBase): test_field: Optional[str] = Field("ORIGIN_VALUE") @@ -57,10 +52,25 @@ def init_classes(): ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ], ) + test_field2: Optional[str] = Field( + "ORIGIN_VALUE", + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test2), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + test_field_both: Optional[str] = Field( + "ORIGIN_VALUE", + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test2), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) @register_rest_rootpoint class RootApp(RestResourceBaseLogin): - _ar_user_login: ClassVar[list[UserLogin]] = [user_test] + _ar_user_login: ClassVar[list[UserLogin]] = [user_test, user_test2] test_resourceACL: TestResource = Field( TestResource(), ACL=[ @@ -92,6 +102,113 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): def setUp(self) -> None: chdir(testdir_path.parent.resolve()) + def test_login_two_users(self): + ip, port = find_free_port() + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + # unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field2", json="TEST SET VALUE") + self.assertEqual(response.status_code, 403) + + # not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field2", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field_both", json="TEST SET VALUE 2") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field_both", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") + + # --------------------------------------- + # login 2 + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser2", "secret": "abcdef"}, + ) + self.assertEqual(response.status_code, 201) + + # unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="A TEST SET VALUE") + self.assertEqual(response.status_code, 403) + + # not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field2", json="A TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field2", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "A TEST SET VALUE") + + # previous (modified) value + response = s.get( + f"http://{ip}:{port}/test_resource/test_field_both", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field_both", json="A TEST SET VALUE 2") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field_both", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "A TEST SET VALUE 2") + + finally: + proc.terminate() + s.close() + def test_login(self): ip, port = find_free_port() proc = Process( @@ -146,6 +263,239 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): proc.terminate() s.close() + def test_change_host(self): + ip, port = find_free_port() + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s1 = requests.Session() + s1.mount("http://", HTTPAdapter(max_retries=0)) + s2 = requests.Session() + s2.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # s1 - read full login resource + response = s1.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s1 - read login username field + response = s1.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # s2 - read full login resource + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s2 - read login username field + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # login s1 + response = s1.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # s1 - read full login resource + response = s1.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "TestUser"}) + + # s1 - read login username field + response = s1.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TestUser") + + # s2 - read full login resource + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s2 - read login username field + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # s2 -> spoof s1 token + s2.cookies.update(s1.cookies) + + # s2 - read full login resource + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s2.cookies.get_dict(), {}) + + # s2 - read full login resource (reseted cookie) + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s2 -> spoof s1 token + s2.cookies.update(s1.cookies) + + # s2 - read login username field + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s2.cookies.get_dict(), {}) + + # s2 - read full login resource (reseted cookie) + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + finally: + proc.terminate() + s1.close() + s2.close() + + def test_login_wrong_pwd(self): + ip, port = find_free_port() + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + self.assertDictEqual(s.cookies.get_dict(), {}) + + # --------------------------------------------------- + # login (wrong pwd) + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "abc"}, + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + self.assertDictEqual(s.cookies.get_dict(), {}) + + # --------------------------------------------------- + # login (ok pwd) + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + self.assertTrue("Authorization" in response.cookies) + self.assertTrue("Authorization" in s.cookies.get_dict()) + self.assertTrue(s.cookies.get_dict()["Authorization"]) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "TestUser"}) + self.assertTrue("Authorization" in s.cookies.get_dict()) + self.assertTrue(s.cookies.get_dict()["Authorization"]) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TestUser") + self.assertTrue("Authorization" in s.cookies.get_dict()) + self.assertTrue(s.cookies.get_dict()["Authorization"]) + + # --------------------------------------------------- + # login (wrong pwd, after success) + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "abc"}, + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + self.assertDictEqual(s.cookies.get_dict(), {}) + + finally: + proc.terminate() + s.close() + def test_access_resourceACL(self): ip, port = find_free_port() proc = Process( @@ -170,7 +520,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): # try unauthenticated write (to field) response = s.put(f"http://{ip}:{port}/test_resourceACL/test_field", json="TEST SET VALUE") - self.assertEqual(response.status_code, 500) + self.assertEqual(response.status_code, 403) # check not modified response = s.get( @@ -181,7 +531,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): # try unauthenticated write (to resource) response = s.put(f"http://{ip}:{port}/test_resourceACL", json={"test_field": "TEST SET VALUE"}) - self.assertEqual(response.status_code, 500) + self.assertEqual(response.status_code, 403) # check not modified response = s.get( @@ -247,7 +597,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): # try unauthenticated write (to field) response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") - self.assertEqual(response.status_code, 500) + self.assertEqual(response.status_code, 403) # check not modified response = s.get( @@ -258,7 +608,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): # try unauthenticated write (to resource) response = s.put(f"http://{ip}:{port}/test_resource", json={"test_field": "TEST SET VALUE"}) - self.assertEqual(response.status_code, 500) + self.assertEqual(response.status_code, 403) # check not modified response = s.get( diff --git a/test/test_rest_resource.py b/test/test_rest_resource.py index f3b7c9d..2fd3d97 100644 --- a/test/test_rest_resource.py +++ b/test/test_rest_resource.py @@ -14,6 +14,7 @@ print(__name__) print(__package__) from src.pyrestresource import ( + RestResourceHandlerException_Forbiden, register_rest_rootpoint, RestResourceBase, rsrc_verb, @@ -268,11 +269,11 @@ class Test_RestAPI_GET(unittest.TestCase): self.assertEqual(result.get_result(), '"chacha"') def test_get_dict_user_element__nested_value__forbiden(self): - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002/secret", rsrc_verb.GET) def test_get_dict_user_element__nested_value__forbiden2(self): - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request( "/users/8da57a3c-661f-11ee-8c99-0242ac120002/secret?API_nested=True", rsrc_verb.GET, @@ -302,7 +303,7 @@ class Test_RestAPI_PUT(unittest.TestCase): self.assertEqual(result.get_result(), '"chacha2"') def test_put_user_nested_value__forbiden(self): - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request( "/users/8da57a3c-661f-11ee-8c99-0242ac120002/uuid", rsrc_verb.PUT, diff --git a/test/test_rest_resource_plugins.py b/test/test_rest_resource_plugins.py index e797b13..358ee55 100644 --- a/test/test_rest_resource_plugins.py +++ b/test/test_rest_resource_plugins.py @@ -35,7 +35,6 @@ def init_classes(): class ResourcePlugin_Info(ResourcePlugin_RestResourceBase_default): def handle_resource_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get: - print("HOOK GET !!") return Info_get(version="65.45", api_version="98.321") class Info_get(RestResourceBase): @@ -95,9 +94,10 @@ class Test_RestAPI_Plugin_PUT(unittest.TestCase): self.testapp.process_request("/info_put/version", rsrc_verb.PUT, '"1.5.6"') result = self.testapp.process_request("/info_put", rsrc_verb.GET) - print(result.get_result()) + self.assertEqual(result.get_result(), '{"version": "42", "api_version": "0.0.2"}') + result = self.testapp.process_request("/info_put/version", rsrc_verb.GET) - print(result.get_result()) + self.assertEqual(result.get_result(), '"42"') def test_put_field_version_resourceplugin(self): diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py index 9329214..5c86163 100644 --- a/test/test_rest_webserver.py +++ b/test/test_rest_webserver.py @@ -154,15 +154,6 @@ class Test_RestAPI_WebServer(unittest.TestCase): ["9b0381d4-65f6-11ee-8c99-0242ac120002"], ) - # Login in - """ - response = s.post( - f"http://{ip}:{port}/login", - params={"username": "test", "password": "test"}, - ) - self.assertEqual(response.status_code, 200) - """ - # Add a new one (with all values setted) response = s.post( f"http://{ip}:{port}/games", -- 2.47.3 From 4dc7243900f5825178c0a2a69bf5715f929868c8 Mon Sep 17 00:00:00 2001 From: cclecle Date: Sun, 5 Nov 2023 22:17:33 +0000 Subject: [PATCH 11/20] add missing typegard dep --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 04f614b..1bbd549 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ classifiers = [ ] dependencies = [ 'packaging', + 'typegard', 'pydantic>=2.4,<3', 'uvicorn>=0.23' ] -- 2.47.3 From 04ef407a6f021701665dbb1f70c4db732521d30e Mon Sep 17 00:00:00 2001 From: cclecle Date: Sun, 5 Nov 2023 22:21:07 +0000 Subject: [PATCH 12/20] typo fix --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1bbd549..8659f19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ ] dependencies = [ 'packaging', - 'typegard', + 'typeguard', 'pydantic>=2.4,<3', 'uvicorn>=0.23' ] -- 2.47.3 From d58173f07b7ad031622ec9e2338d91d15cdbf415 Mon Sep 17 00:00:00 2001 From: cclecle Date: Mon, 6 Nov 2023 01:11:43 +0000 Subject: [PATCH 13/20] fix typing issues --- RUN_quality.launch | 2 +- RUN_types.launch | 17 +++ pyproject.toml | 4 + src/pyrestresource/__init__.py | 28 ++--- src/pyrestresource/helpers.py | 4 +- src/pyrestresource/rest_ACL.py | 2 +- src/pyrestresource/rest_exceptions.py | 4 + src/pyrestresource/rest_login.py | 25 +++-- src/pyrestresource/rest_model.py | 103 ++++++++++++++++++ src/pyrestresource/rest_request.py | 22 +++- src/pyrestresource/rest_resource.py | 29 ++--- src/pyrestresource/rest_resource_handler.py | 43 +++++--- .../rest_resource_handler_walker.py | 43 ++++---- src/pyrestresource/rest_resource_plugin.py | 13 ++- src/pyrestresource/rest_resource_rootpoint.py | 25 +++-- src/pyrestresource/rest_resource_walker.py | 11 +- src/pyrestresource/rest_types.py | 3 +- test/test_ACL.py | 16 +-- test/test_rest_login.py | 12 +- test/test_rest_resource.py | 14 +-- test/test_rest_resource_plugins.py | 16 +-- test/test_rest_resource_walker.py | 4 +- test/test_rest_resource_walker_tree.py | 4 +- test/test_rest_webserver.py | 14 +-- 24 files changed, 316 insertions(+), 142 deletions(-) create mode 100644 RUN_types.launch create mode 100644 src/pyrestresource/rest_model.py diff --git a/RUN_quality.launch b/RUN_quality.launch index 079d5ed..95fbee2 100644 --- a/RUN_quality.launch +++ b/RUN_quality.launch @@ -10,7 +10,7 @@ - + diff --git a/RUN_types.launch b/RUN_types.launch new file mode 100644 index 0000000..747af5d --- /dev/null +++ b/RUN_types.launch @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/pyproject.toml b/pyproject.toml index 8659f19..8e85e4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,10 @@ where = ["src"] "pyrestresource.data" = ["*.*"] "pyrestresource" = ["py.typed"] +# [[tool.mypy.overrides]] +# module = "pydantic.pydantic_core" +# ignore_missing_imports = true + [project.urls] Homepage = "https://chacha.ddns.net/gitea/chacha/pyrestresource" Documentation = "https://chacha.ddns.net/mkdocs-web/chacha/pyrestresource/master/latest/" diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index 08402e3..6e63320 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -18,22 +18,9 @@ from typing import TYPE_CHECKING from .__metadata__ import __version__, __Summuary__, __Name__ - -from .rest_resource import RestResourceBase +from .rest_model import RestField from .rest_resource_rootpoint import register_rest_rootpoint - from .rest_types import rsrc_verb, T_SupportedRESTFields - -if TYPE_CHECKING: - from .rest_types import ( - T_ListIndex, - T_ListSize, - T_DictKey, - T_T_DictKey, - T_DictValues, - T_T_DictValues, - ) - from .rest_request_opt import ( RestRequestParams_POST, RestRequestParams_DELETE, @@ -45,18 +32,17 @@ from .rest_request_opt import ( RestRequestParams_Dict_DELETE, RestRequestParams_Dict_GET, ) - from .rest_resource_plugin import ( ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, ResourcePlugin_dict_default, ) from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule +from .rest_resource import RestResourceBase from .rest_login import ( RestResourceBaseLogin, UserLogin, ) - from .rest_exceptions import ( RestResourceException, RestResourceLoginException, @@ -67,3 +53,13 @@ from .rest_exceptions import ( RestResourcePluginException_InvalidPluginSignature, RestResourceHandlerException_Forbiden, ) + +if TYPE_CHECKING: + from .rest_types import ( + T_ListIndex, + T_ListSize, + T_DictKey, + T_T_DictKey, + T_DictValues, + T_T_DictValues, + ) diff --git a/src/pyrestresource/helpers.py b/src/pyrestresource/helpers.py index 2d2723e..8c569ea 100644 --- a/src/pyrestresource/helpers.py +++ b/src/pyrestresource/helpers.py @@ -19,8 +19,8 @@ class _JSONEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, o) -def parse_dict_cookies(cookies: str) -> dict[str, str]: - result = {} +def parse_dict_cookies(cookies: str) -> dict[str, str | None]: + result: dict[str, str | None] = {} for item in cookies.split(";"): item = item.strip() if not item: diff --git a/src/pyrestresource/rest_ACL.py b/src/pyrestresource/rest_ACL.py index e7158e4..bde356d 100644 --- a/src/pyrestresource/rest_ACL.py +++ b/src/pyrestresource/rest_ACL.py @@ -22,7 +22,7 @@ class ACL_target_user(ACL_target): return cls(name=user_login.username) -class ACL_target_user_Annonymous(ACL_target): +class ACL_target_user_Annonymous(ACL_target_user): name: str = "__ANNONYMOUS__" diff --git a/src/pyrestresource/rest_exceptions.py b/src/pyrestresource/rest_exceptions.py index c779dac..64d582c 100644 --- a/src/pyrestresource/rest_exceptions.py +++ b/src/pyrestresource/rest_exceptions.py @@ -2,6 +2,10 @@ class RestResourceException(Exception): pass +class RestResourceConfigException(RestResourceException): + pass + + class RestResourceModelException(RestResourceException): pass diff --git a/src/pyrestresource/rest_login.py b/src/pyrestresource/rest_login.py index f91de8c..4857d7f 100644 --- a/src/pyrestresource/rest_login.py +++ b/src/pyrestresource/rest_login.py @@ -16,10 +16,11 @@ from typing import Optional, ClassVar, TYPE_CHECKING from secrets import token_hex, compare_digest from datetime import datetime, timedelta -from pydantic import BaseModel, Field +from pydantic import BaseModel from .rest_types import rsrc_verb from .rest_resource import RestResourceBase +from .rest_model import RestField from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule, ACL_target_user from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default from .rest_exceptions import ( @@ -30,7 +31,8 @@ from .rest_exceptions import ( ) if TYPE_CHECKING is True: - from .rest_request import RestRequest, RestRequestParams_GET + from .rest_request import RestRequest + from .rest_request_opt import RestRequestParams_RestResourceBase_PUT, RestRequestParams_RestResourceBase_GET class UserLogin(BaseModel): @@ -41,24 +43,26 @@ class UserLogin(BaseModel): class UserSession(BaseModel): last_update: datetime user_login: UserLogin - client: Optional[tuple[str, int]] + client: tuple[str, int] | tuple[()] | None class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): ar_UserLogin: list[UserLogin] = [] - def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login: - return Login(username=self.get_user_login()) + def handle_resource_get(self, resource: Login, params: RestRequestParams_RestResourceBase_GET) -> Login: + return Login(username=self.get_user_login(), secret=None) - def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login: + def handle_resource_put(self, resource: Login, params: RestRequestParams_RestResourceBase_PUT) -> Login: + if resource.username is None or resource.secret is None: + raise RestResourceLoginException_InvalidCredentials() token = self.user_login(resource.username, resource.secret) self.set_resp_cookie_value("Authorization", f"Bearer {token}") return resource class Login(RestResourceBase): - username: Optional[str] = Field(None) - secret: Optional[str] = Field( + username: Optional[str] = RestField(None) + secret: Optional[str] = RestField( None, exclude=True, ACL=[ @@ -73,7 +77,7 @@ class RestResourceBaseLogin(RestResourceBase): _ar_user_session: dict[str, UserSession] = {} _max_session_inactive: ClassVar[timedelta] = timedelta(minutes=20) _max_session_time: ClassVar[timedelta] = timedelta(hours=12) - login: Login = Field(default=Login(), plugin=ResourcePlugin_Login) + login: Login = RestField(default=Login(), plugin=ResourcePlugin_Login) def get_new_cookie_expiration_date(self) -> datetime: return datetime.now() + self._max_session_time @@ -120,8 +124,7 @@ class RestResourceBaseLogin(RestResourceBase): pass pass - if already_failed: - raise RestResourceLoginException_InvalidCredentials() + raise RestResourceLoginException_InvalidCredentials() def _register_user_session(self, user_login: UserLogin, request: RestRequest) -> str: token = token_hex(16) diff --git a/src/pyrestresource/rest_model.py b/src/pyrestresource/rest_model.py new file mode 100644 index 0000000..3f5051b --- /dev/null +++ b/src/pyrestresource/rest_model.py @@ -0,0 +1,103 @@ +from __future__ import annotations +from typing import ( + Any, + Literal, + Callable, + Optional, + TYPE_CHECKING, +) + +from pydantic.fields import Field, _Unset, PydanticUndefined + +from .rest_exceptions import RestResourceModelException + +if TYPE_CHECKING is True: + from .rest_ACL import ACL_record + from .rest_resource_plugin import ResourcePlugin + from typing import Unpack + from pydantic.fields import _EmptyKwargs, AliasPath, AliasChoices + + +def RestField( + default: Any = PydanticUndefined, + *, + default_factory: Callable[[], Any] | None = _Unset, + alias: str | None = _Unset, + alias_priority: int | None = _Unset, + validation_alias: str | AliasPath | AliasChoices | None = _Unset, + serialization_alias: str | None = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + examples: list[Any] | None = _Unset, + exclude: bool | None = _Unset, + discriminator: str | None = _Unset, + json_schema_extra: dict[str, Any] | Callable[[dict[str, Any]], None] | None = _Unset, + frozen: bool | None = _Unset, + validate_default: bool | None = _Unset, + repr: bool = _Unset, + init_var: bool | None = _Unset, + kw_only: bool | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + union_mode: Literal["smart", "left_to_right"] = _Unset, + ACL: Optional[list[ACL_record]] = _Unset, + plugin: Optional[type[ResourcePlugin]] = _Unset, + **extra: Unpack[_EmptyKwargs], +) -> Any: + if not json_schema_extra or json_schema_extra is _Unset: + if extra: + json_schema_extra = extra # type: ignore + else: + json_schema_extra = {} + + if ACL is not _Unset: + json_schema_extra["ACL"] = ACL + + if plugin is not _Unset: + json_schema_extra["plugin"] = plugin + else: + raise RestResourceModelException("json_schema_extra must not be set") + + return Field( + default, + default_factory=default_factory, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + examples=examples, + exclude=exclude, + discriminator=discriminator, + json_schema_extra=json_schema_extra, + frozen=frozen, + validate_default=validate_default, + repr=repr, + init_var=init_var, + kw_only=kw_only, + pattern=pattern, + strict=strict, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + min_length=min_length, + max_length=max_length, + union_mode=union_mode, + **extra, + ) diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index e7dd54c..a774525 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -13,6 +13,7 @@ from urllib.parse import urlparse, parse_qs from pydantic import BaseModel, Field from typeguard import check_type +from .rest_login import RestResourceBaseLogin from .rest_types import rsrc_verb, T_AllSupportedFields from .rest_request_opt import ( RestRequestParams_POST, @@ -27,6 +28,12 @@ from .rest_request_opt import ( ) from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group from .helpers import parse_dict_cookies +from .rest_exceptions import ( + RestResourceHandlerException_MethodNotAllowed, + RestResourceHandlerException_BadRequest, + RestResourceException, + RestResourceConfigException, +) if TYPE_CHECKING is True: from typing import Optional @@ -121,7 +128,7 @@ class RestRequest(Generic[_T_RestRequestParams]): self.verb: rsrc_verb self.data: dict self._raw_headers: list[Any] = [] - self._client: tuple[str, int] = () + self._client: tuple[str, int] | tuple[()] = () self.headers: dict[str, None | str | dict[str, None | str]] = {"host": None, "cookie": {}} self._saved_url_params: dict self.ReqParams: _T_RestRequestParams = type_request_params() @@ -179,7 +186,7 @@ class RestRequest(Generic[_T_RestRequestParams]): def set_client(self, client: tuple[str, int]) -> None: self._client = client - def get_client(self) -> tuple[str, int]: + def get_client(self) -> tuple[str, int] | tuple[()]: return self._client def set_headers(self, headers: list[Any]) -> None: @@ -193,11 +200,18 @@ class RestRequest(Generic[_T_RestRequestParams]): self.headers["cookie"] = parse_dict_cookies(elem[1].decode("utf-8")) def get_cookie(self, key: str) -> str | None: + if self.headers["cookie"] is None: + return None if key not in self.headers["cookie"]: return None - return self.headers["cookie"][key] + if isinstance(self.headers["cookie"], dict): + return self.headers["cookie"][key] + else: + return None def set_resp_cookie_value(self, key: str, value: str) -> None: + if not isinstance(self.root_resource, RestResourceBaseLogin): + raise RestResourceConfigException("root_resource must be RestResourceBaseLogin to use user_login") self.outgoing_cookie[ key ] = f"{value}; expires={self.root_resource.get_new_cookie_expiration_date().strftime('%a, %d %b %Y %H:%M:%S GMT')}; path=/; HttpOnly" @@ -205,7 +219,7 @@ class RestRequest(Generic[_T_RestRequestParams]): def reset_resp_cookie(self, key: str) -> None: self.outgoing_cookie[key] = "null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT" - def get_host(self) -> str: + def get_host(self) -> str | dict[str, str | None] | None: return self.headers["host"] def set_result(self, result: str): diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index b7c00c0..aca8879 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -12,8 +12,8 @@ import pprint from pydantic import BaseModel +from .rest_types import rsrc_verb from .helpers import _JSONEncoder, forward_exception -from .rest_types import rsrc_verb, _T_SupportedRESTFields from .rest_ACL import ( ACL_record, @@ -23,7 +23,6 @@ from .rest_ACL import ( ACL_rule, ) -from .rest_request import RestRequest from .rest_exceptions import ( RestResourceLoginException_InvalidSession, RestResourceLoginException_SessionTimeout, @@ -37,6 +36,9 @@ from .rest_exceptions import ( ) if TYPE_CHECKING is True: + from .rest_request import RestRequest + from .rest_types import T_SupportedRESTFields + from .rest_resource_plugin import ResourcePlugin from .rest_types import ( T_T_DictKey, T_T_DictValues, @@ -44,7 +46,6 @@ if TYPE_CHECKING is True: class RestResourceBase(ABC, BaseModel, validate_assignment=True): - # _resp_cookies: ClassVar[dict[str, str]] = {} _dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {} _dict_value_type_: ClassVar[dict[str, T_T_DictValues]] = {} _model_dump_excluded_: ClassVar[dict[str, bool]] = {} @@ -52,13 +53,13 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): _plugins_: ClassVar[ dict[ str, - list[ACL_record], + type[ResourcePlugin], ] ] = {} _ACL_record_: ClassVar[ dict[ str, - ACL_record, + list[ACL_record], ] ] = {} @@ -92,15 +93,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): """Check ACL on requested field access""" self._check_acl(request.user, request.groups, request.get_verb(), request.get_resource_origin(req_index), False) - def check_acl_self(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None: + def check_acl_self(self, request: RestRequest, new_data: Optional[dict[str, T_SupportedRESTFields]]) -> None: """Check ACL on requested field operation (involving checking sub-fields)""" if request.get_verb() is rsrc_verb.GET: for key in self.model_fields.keys(): self._check_acl(request.user, request.groups, rsrc_verb.GET, key) elif request.get_verb() is rsrc_verb.PUT: - for key in new_data.keys(): - if key in self.model_fields: - self._check_acl(request.user, request.groups, rsrc_verb.PUT, key) + if new_data is not None: + for key in new_data.keys(): + if key in self.model_fields: + self._check_acl(request.user, request.groups, rsrc_verb.PUT, key) else: raise RestResourceException("Incompatible verb") @@ -122,7 +124,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): return body - async def __call__(self, scope, receive, send): + async def __call__(self, scope, receive, send) -> None: assert scope["type"] == "http" method = scope["method"] @@ -148,7 +150,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): assert request is not None - header_resp = { + header_resp: dict[str, Any] = { "type": "http.response.start", "status": request.get_status(), "headers": [ @@ -162,8 +164,9 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): await send(header_resp) body = None - if request.get_result(): - body = request.get_result().encode("utf-8") + result = request.get_result() + if result: + body = result.encode("utf-8") await send( { diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 3197c53..6435cbf 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -4,6 +4,7 @@ from typing import Optional, cast, TypeVar, Generic, Self, TYPE_CHECKING import abc from .rest_types import ( + NoneType, rsrc_verb, T_SupportedRESTFields, T_DictKey, @@ -15,6 +16,7 @@ from .rest_resource import RestResourceBase from .rest_request import RequestFactory from .rest_resource_plugin import ( ResourcePlugin_field, + ResourcePlugin_dict, ResourcePlugin_RestResourceBase, ) from .rest_request_opt import ( @@ -100,8 +102,8 @@ class ResourceHandler( elif None in [url, verb]: raise RestResourceHandlerException("if req not set, url,verb must be setted") else: - if url is None or verb is None: - raise RestResourceHandlerException("url and verb must be set") + assert url is not None and verb is not None + assert isinstance(resource, RestResourceBase) if data is None: data = {} self.req = self._request_factory.get_RestRequest(resource, url, verb, data, query_string) @@ -317,11 +319,16 @@ class ResourceHandler_dict( assert self.prev_handler is not None dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] + dict_value_type: T_T_DictValues = cast(RestResourceBase, self.prev_handler.resource)._dict_value_type_[ self.req.get_resource_origin(1) ] - _obj = dict_value_type(**self.req.get_data()) + _obj: T_DictValues + if not issubclass(dict_value_type, NoneType): # type: ignore # => mypy bug with type[None] + _obj = dict_value_type(**self.req.get_data()) # type: ignore # => mypy bug with type[None] + else: + _obj = None _dict: dict[T_DictKey, T_DictValues] = cast(dict[T_DictKey, T_DictValues], self.resource) @@ -479,22 +486,23 @@ class ResourceHandler_RestResourceBase( # CASE 1: no more item in url_stack => we reached the endpoint (operation) # So we are in a RestResourceBase instance and must return the content + plugin_field: ResourcePlugin_field + plugin_resource: ResourcePlugin_RestResourceBase + if len(self.req.get_url_stack()) == 0: self.resource.check_acl_self(self.req, None) for key, attr in self.resource.model_fields.items(): if key in self.resource._plugins_: if issubclass(self.resource._plugins_[key], ResourcePlugin_field): - plugin_field: ResourcePlugin_field = cast( - ResourcePlugin_field, self.resource._plugins_[key](self.req, self.req.get_root_resource()) - ) + plugin_field = cast(ResourcePlugin_field, self.resource._plugins_[key](self.req, self.req.get_root_resource())) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_field_get(value, params)) elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): - plugin_field: ResourcePlugin_field = cast( + plugin_resource = cast( ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.req.get_root_resource()) ) value = getattr(self.resource, key) - setattr(self.resource, key, plugin_field.handle_resource_get(value, params)) + setattr(self.resource, key, plugin_resource.handle_resource_get(value, params)) # result = RestResourceWalker_Root__handler(self.resource).process() # print(result) @@ -512,18 +520,18 @@ class ResourceHandler_RestResourceBase( key = self.req.get_resource_origin(0) if key in self.resource._plugins_: if issubclass(self.resource._plugins_[key], ResourcePlugin_field): - plugin_rsrc: ResourcePlugin_RestResourceBase = cast( - ResourcePlugin_RestResourceBase, + plugin_field = cast( + ResourcePlugin_field, self.resource._plugins_[key](self.req, self.req.get_root_resource()), ) - value = plugin_rsrc.handle_field_get(value, params) + value = plugin_field.handle_field_get(value, params) elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): - plugin_rsrc: ResourcePlugin_RestResourceBase = cast( + plugin_resource = cast( ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.req.get_root_resource()), ) - value = plugin_rsrc.handle_resource_get(value, params) + value = plugin_resource.handle_resource_get(value, params) return value @@ -551,18 +559,19 @@ class ResourceHandler_RestResourceBase( # applying plugins (from parent element) if self.prev_handler is not None: + # element is within a dict if ( - isinstance(self.prev_handler.resource, dict) # element is within a dict + isinstance(self.prev_handler.resource, dict) and self.prev_handler.prev_handler is not None and isinstance(self.prev_handler.prev_handler.resource, RestResourceBase) ): key = self.req.get_resource_origin(2) if key in self.prev_handler.prev_handler.resource._plugins_: - plugin_rsrc: ResourcePlugin_RestResourceBase = cast( - ResourcePlugin_RestResourceBase, + plugin_dict: ResourcePlugin_dict = cast( + ResourcePlugin_dict, self.prev_handler.prev_handler.resource._plugins_[key](self.req, self.req.get_root_resource()), ) - _new_resrc = plugin_rsrc.handle_dict_elem_put(_new_resrc, params) + _new_resrc = plugin_dict.handle_dict_elem_put(_new_resrc, params) # element is within a RestResourceBase elif isinstance(self.prev_handler.resource, RestResourceBase): key = self.req.get_resource_origin(1) diff --git a/src/pyrestresource/rest_resource_handler_walker.py b/src/pyrestresource/rest_resource_handler_walker.py index c0f06ac..62f0df9 100644 --- a/src/pyrestresource/rest_resource_handler_walker.py +++ b/src/pyrestresource/rest_resource_handler_walker.py @@ -23,40 +23,43 @@ from .rest_resource_walker import ( ) if TYPE_CHECKING is True: - from typing import Optional + from typing import Optional, Any class RestResourceWalkerFutureResult_RestResourceBase_handler(RestResourceWalkerFutureResult[dict]): def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: - print(f"RestResourceWalkerFutureResult_RestResourceBase_handler {result}") - res = {} - res[self.source.resource_name] = dict() - for subres in result: - key = next(iter(subres)) - print(key) - res[self.source.resource_name] = res[self.source.resource_name] | subres + # print(f"RestResourceWalkerFutureResult_RestResourceBase_handler {result}") + res: dict[str, Any] = {} + res[self.source.resource_name] = {} + if result: + for subres in result: + key = next(iter(subres)) + print(key) + res[self.source.resource_name] = res[self.source.resource_name] | subres return res class RestResourceWalkerFutureResult_Dict_handler(RestResourceWalkerFutureResult[dict]): def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: - print(f"RestResourceWalkerFutureResult_Dict_handler {result}") - res = {} - for subres in result: - res = res | subres + # print(f"RestResourceWalkerFutureResult_Dict_handler {result}") + res: dict[str, Any] = {} + if result: + for subres in result: + res = res | subres return res class RestResourceWalkerFutureResult_RestFields_handler(RestResourceWalkerFutureResult[dict]): def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: - print(f"RestResourceWalkerFutureResult_RestFields_handler {result}") - print(self.source.resource) - res = {} - res[self.source.resource_name] = dict() - for subres in result: - key = next(iter(subres)) - print(key) - res[self.source.resource_name] = res[self.source.resource_name] | subres + # print(f"RestResourceWalkerFutureResult_RestFields_handler {result}") + # print(self.source.resource) + res: dict[str, Any] = {} + res[self.source.resource_name] = {} + if result: + for subres in result: + key = next(iter(subres)) + print(key) + res[self.source.resource_name] = res[self.source.resource_name] | subres return res diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py index 6a97876..a110c1e 100644 --- a/src/pyrestresource/rest_resource_plugin.py +++ b/src/pyrestresource/rest_resource_plugin.py @@ -10,10 +10,11 @@ from .rest_types import ( TV_SupportedRESTFields, TV_RestResourceBase, ) -from .rest_request import RestRequest +from .rest_exceptions import RestResourceConfigException if TYPE_CHECKING is True: + from .rest_request import RestRequest from .rest_resource import RestResourceBase from .rest_request_opt import ( RestRequestParams_GET, @@ -31,9 +32,13 @@ if TYPE_CHECKING is True: class ResourcePlugin(ABC): def __init__(self, request: RestRequest, root_resource: RestResourceBase) -> None: self.__request: RestRequest = request - self.__root_resource: RestRequest = root_resource + self.__root_resource: RestResourceBase = root_resource def user_login(self, user_name: str, user_secret: str) -> str: + from .rest_login import RestResourceBaseLogin + + if not isinstance(self.__root_resource, RestResourceBaseLogin): + raise RestResourceConfigException("root_resource must be RestResourceBaseLogin to use user_login") return self.__root_resource.user_login(user_name, user_secret, self.__request) def get_user_login(self) -> str: @@ -46,6 +51,10 @@ class ResourcePlugin(ABC): self.__request.reset_resp_cookie(key) def get_new_cookie_expiration_date(self) -> datetime: + from .rest_login import RestResourceBaseLogin + + if not isinstance(self.__root_resource, RestResourceBaseLogin): + raise RestResourceConfigException("root_resource must be RestResourceBaseLogin to use get_new_cookie_expiration_date") return self.__root_resource.get_new_cookie_expiration_date() def set_resp_status(self, status: int) -> None: diff --git a/src/pyrestresource/rest_resource_rootpoint.py b/src/pyrestresource/rest_resource_rootpoint.py index ac62b58..caec5cd 100644 --- a/src/pyrestresource/rest_resource_rootpoint.py +++ b/src/pyrestresource/rest_resource_rootpoint.py @@ -1,10 +1,12 @@ from __future__ import annotations from typing import ( + cast, get_args, get_origin, TYPE_CHECKING, ) +from pydantic import BaseModel from pydantic.fields import FieldInfo from .rest_resource import RestResourceBase @@ -47,8 +49,13 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): self.parent.annotation._dict_value_type_[self.resource_name] = datatype[1] # pylint: disable=protected-access self.parent.annotation._model_dump_excluded_[self.resource_name] = True # pylint: disable=protected-access - self.resource.exclude = True - self.parent.resource.model_rebuild(force=True) + assert isinstance(self.resource, FieldInfo) + current_resource = cast(FieldInfo, self.resource) + current_resource.exclude = True + + parent_resource = cast(type[RestResourceBase], self.parent.resource) + assert issubclass(parent_resource, RestResourceBase) + parent_resource.model_rebuild(force=True) self.parent.annotation._ACL_record_[self.resource_name] = [] @@ -58,7 +65,7 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): and type(self.resource.json_schema_extra) is dict ): if "plugin" in self.resource.json_schema_extra: - plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"] + plugin_dict: type[ResourcePlugin_dict] = self.resource.json_schema_extra["plugin"] if not issubclass(plugin_dict, ResourcePlugin_dict): raise RestResourcePluginException_InvalidPluginSignature() self.parent.annotation._plugins_[self.resource_name] = plugin_dict @@ -105,7 +112,7 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi ] if "plugin" in self.resource.json_schema_extra: - plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"] + plugin_field: type[ResourcePlugin_field] = self.resource.json_schema_extra["plugin"] if not issubclass(plugin_field, ResourcePlugin_field): raise RestResourcePluginException_InvalidPluginSignature() self.parent.annotation._plugins_[self.resource_name] = plugin_field @@ -134,8 +141,12 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ # preprocessing types / structure if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): self.parent.annotation._model_dump_excluded_[self.resource_name] = True - self.resource.exclude = True - self.parent.resource.model_rebuild(force=True) + assert isinstance(self.resource, FieldInfo) + current_resource = cast(FieldInfo, self.resource) + current_resource.exclude = True + parent_resource = cast(type[RestResourceBase], self.parent.resource) + assert issubclass(parent_resource, RestResourceBase) + parent_resource.model_rebuild(force=True) self.parent.annotation._ACL_record_[self.resource_name] = [] if ( @@ -144,7 +155,7 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ and type(self.resource.json_schema_extra) is dict ): if "plugin" in self.resource.json_schema_extra: - plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] + plugin_resource: type[ResourcePlugin_RestResourceBase] = self.resource.json_schema_extra["plugin"] if not issubclass(plugin_resource, ResourcePlugin_RestResourceBase): raise RestResourcePluginException_InvalidPluginSignature() self.parent.annotation._plugins_[self.resource_name] = plugin_resource diff --git a/src/pyrestresource/rest_resource_walker.py b/src/pyrestresource/rest_resource_walker.py index f3d14fe..fc68aa9 100644 --- a/src/pyrestresource/rest_resource_walker.py +++ b/src/pyrestresource/rest_resource_walker.py @@ -53,7 +53,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): resource_name: str, resource: FieldInfo | Type[RestResourceBase], parent: Optional[RestResourceWalker_Sub] = None, - argument: Optional[any] = None, + argument: Optional[Any] = None, ) -> Optional[RestResourceWalker_Sub]: for sub in subs: _is_valid, _anno, _optional = sub.check_type(resource) @@ -70,9 +70,9 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): parent: Optional[RestResourceWalker_Sub] = None, annotation: Optional[type[RestResourceBase]] = None, _optional: Optional[bool] = None, - argument: Optional[any] = None, + argument: Optional[Any] = None, ): - self.argument: any = argument + self.argument: Any = argument self.resource_name: str = resource_name self.resource: FieldInfo | Type[RestResourceBase] = resource self.parent: Optional[RestResourceWalker_Sub] = parent @@ -208,14 +208,14 @@ class RestResourceWalker_Root: ] def __init__(self, resource: RestResourceBase | Type[RestResourceBase]) -> None: - self.subwalker_argument: any = None + self.subwalker_argument: Any = None self.resource: Type[RestResourceBase] if isinstance(resource, RestResourceBase): self.resource = type(resource) else: self.resource = resource - def process(self, argument: Optional[any] = None, deep_limit: Optional[int] = None) -> Optional[TV_RestResourceWalkerFutureResult]: + def process(self, argument: Optional[Any] = None, deep_limit: Optional[int] = None) -> Optional[TV_RestResourceWalkerFutureResult]: current_deep: int = 0 sub_walker_initial: Optional[RestResourceWalker_Sub] = RestResourceWalker_Sub.get( @@ -254,4 +254,3 @@ class RestResourceWalker_Root: return sub_walker_initial.chain_process_future() else: raise RestResourceModelException("Invalid Rootpoint") - return None diff --git a/src/pyrestresource/rest_types.py b/src/pyrestresource/rest_types.py index da34850..8898690 100644 --- a/src/pyrestresource/rest_types.py +++ b/src/pyrestresource/rest_types.py @@ -9,8 +9,7 @@ from uuid import UUID from ipaddress import IPv4Address, IPv4Network if TYPE_CHECKING is True: - pass - + from .rest_resource import RestResourceBase T_Gen_DictKeys: type = type({}.keys()) NoneType = type(None) diff --git a/test/test_ACL.py b/test/test_ACL.py index 7847676..ecf3483 100644 --- a/test/test_ACL.py +++ b/test/test_ACL.py @@ -3,9 +3,9 @@ import unittest from os import chdir from pathlib import Path from typing import Optional -from pydantic import Field from src.pyrestresource import ( + RestField, RestResourceHandlerException_Forbiden, register_rest_rootpoint, RestResourceBase, @@ -30,8 +30,8 @@ chdir(testdir_path.parent.resolve()) # to allow mock-ing, all the tested classes are in a function def init_classes(): class TestResource(RestResourceBase): - username: Optional[str] = Field(None) - secret: Optional[str] = Field( + username: Optional[str] = RestField(None) + secret: Optional[str] = RestField( None, exclude=True, ACL=[ @@ -41,21 +41,21 @@ def init_classes(): ) class TestResource2(RestResourceBase): - version_ro: Optional[str] = Field( + version_ro: Optional[str] = RestField( "1.2.3", ACL=[ ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ], ) - version: Optional[str] = Field("3.2.1") + version: Optional[str] = RestField("3.2.1") @register_rest_rootpoint class RootApp(RestResourceBase): - resource_with_secret: TestResource = Field(default=TestResource()) - resource_with_secret_ACL: TestResource = Field( + resource_with_secret: TestResource = RestField(default=TestResource()) + resource_with_secret_ACL: TestResource = RestField( default=TestResource(), ACL=[ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY)] ) - resource_ro: TestResource2 = Field(TestResource2()) + resource_ro: TestResource2 = RestField(TestResource2()) # this add the classes to globals to allow using them later on # => this is only for uinit-testing purpose and is not needed in real use diff --git a/test/test_rest_login.py b/test/test_rest_login.py index 9107fbc..ee369b6 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -3,7 +3,6 @@ import unittest from os import chdir from pathlib import Path from typing import Optional, ClassVar -from pydantic import Field from time import sleep import uvicorn import socket @@ -13,6 +12,7 @@ from multiprocessing import Process from requests.adapters import HTTPAdapter from src.pyrestresource import ( + RestField, ACL_target_user, UserLogin, RestResourceBase, @@ -42,24 +42,24 @@ def init_classes(): user_test2 = UserLogin(username="TestUser2", secret="abcdef") class TestResource(RestResourceBase): - test_field: Optional[str] = Field("ORIGIN_VALUE") + test_field: Optional[str] = RestField("ORIGIN_VALUE") class TestResourceACL(RestResourceBase): - test_field: Optional[str] = Field( + test_field: Optional[str] = RestField( "ORIGIN_VALUE", ACL=[ ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test), rule=ACL_rule.ALLOW), ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ], ) - test_field2: Optional[str] = Field( + test_field2: Optional[str] = RestField( "ORIGIN_VALUE", ACL=[ ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test2), rule=ACL_rule.ALLOW), ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ], ) - test_field_both: Optional[str] = Field( + test_field_both: Optional[str] = RestField( "ORIGIN_VALUE", ACL=[ ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test), rule=ACL_rule.ALLOW), @@ -71,7 +71,7 @@ def init_classes(): @register_rest_rootpoint class RootApp(RestResourceBaseLogin): _ar_user_login: ClassVar[list[UserLogin]] = [user_test, user_test2] - test_resourceACL: TestResource = Field( + test_resourceACL: TestResource = RestField( TestResource(), ACL=[ ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user(name=user_test.username), rule=ACL_rule.ALLOW), diff --git a/test/test_rest_resource.py b/test/test_rest_resource.py index 2fd3d97..9773b9b 100644 --- a/test/test_rest_resource.py +++ b/test/test_rest_resource.py @@ -4,7 +4,6 @@ from unittest.mock import patch from os import chdir from pathlib import Path from typing import Optional -from pydantic import Field from uuid import UUID, uuid4 from time import time import json @@ -14,6 +13,7 @@ print(__name__) print(__package__) from src.pyrestresource import ( + RestField, RestResourceHandlerException_Forbiden, register_rest_rootpoint, RestResourceBase, @@ -39,19 +39,19 @@ def init_classes(): api_version: str class Patch(RestResourceBase): - uuid: UUID = Field(default_factory=uuid4, primary_key=True) + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) shortname: str name: Optional[str] = None description: Optional[str] = None class Profile(RestResourceBase): - uuid: UUID = Field(default_factory=uuid4, primary_key=True) + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) shortname: str name: Optional[str] = None description: Optional[str] = None class Game(RestResourceBase): - uuid: UUID = Field(default_factory=uuid4, primary_key=True) + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) shortname: str name: Optional[str] = None description: Optional[str] = None @@ -62,12 +62,12 @@ def init_classes(): Patch_2 = Patch(uuid="d385a1d2-65fa-11ee-8c99-0242ac120002", shortname="testPatch2") class User(RestResourceBase): - uuid: UUID = Field( + uuid: UUID = RestField( default_factory=uuid4, primary_key=True, ) name: str - secret: str = Field( + secret: str = RestField( ..., exclude=True, ACL=[ @@ -83,7 +83,7 @@ def init_classes(): ) class Patch2(RestResourceBase): - uuid: UUID = Field(default_factory=uuid4, primary_key=True) + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) shortname: str name: Optional[str] = None description: Optional[str] = None diff --git a/test/test_rest_resource_plugins.py b/test/test_rest_resource_plugins.py index 358ee55..eecc27f 100644 --- a/test/test_rest_resource_plugins.py +++ b/test/test_rest_resource_plugins.py @@ -3,9 +3,9 @@ import unittest from os import chdir from pathlib import Path from typing import Annotated -from pydantic import Field from src.pyrestresource import ( + RestField, register_rest_rootpoint, RestResourceBase, rsrc_verb, @@ -40,27 +40,27 @@ def init_classes(): class Info_get(RestResourceBase): # test plugin injection within annotation # + test plugin on a simple field - version: Annotated[str, Field(plugin=ResourcePlugin_version_get)] + version: Annotated[str, RestField(plugin=ResourcePlugin_version_get)] api_version: str class Info_put(RestResourceBase): # test plugin injection within annotation # + test plugin on a simple field - version: Annotated[str, Field(plugin=ResourcePlugin_version_put)] + version: Annotated[str, RestField(plugin=ResourcePlugin_version_put)] api_version: str @register_rest_rootpoint class RootApp(RestResourceBase): # test plugin injection within Field value # + test plugin on a RestResourceBase field - info: Info_get = Field( + info: Info_get = RestField( default=Info_get(version="0.0.1", api_version="0.0.2"), plugin=ResourcePlugin_Info, ) - info_put: Info_put = Field( + info_put: Info_put = RestField( default=Info_put(version="0.0.1", api_version="0.0.2"), ) - info2: Info_get = Field(default=Info_get(version="0.0.2", api_version="0.0.3")) + info2: Info_get = RestField(default=Info_get(version="0.0.2", api_version="0.0.3")) # this add the classes to globals to allow using them later on # => this is only for uinit-testing purpose and is not needed in real use @@ -75,11 +75,11 @@ def init_bad_plugin1(): ... class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] + tetvaluestr: Annotated[str, RestField(plugin=ResourcePlugin_TestResource)] @register_rest_rootpoint class RootApp2(RestResourceBase): - test: TestResource = Field(default=TestResource(tetvaluestr="testvalue")) + test: TestResource = RestField(default=TestResource(tetvaluestr="testvalue")) RootApp2() diff --git a/test/test_rest_resource_walker.py b/test/test_rest_resource_walker.py index ecaefd2..5fbc619 100644 --- a/test/test_rest_resource_walker.py +++ b/test/test_rest_resource_walker.py @@ -5,7 +5,6 @@ from typing import Optional from os import chdir from pathlib import Path -from pydantic import Field from io import StringIO from contextlib import redirect_stdout @@ -13,6 +12,7 @@ print(__name__) print(__package__) from src.pyrestresource import ( + RestField, RestResourceBase, ) @@ -80,7 +80,7 @@ def init_classes(): last_name: str class RootApp(RestResourceBase): - info: Info = Field(default=Info(version="0.0.1", api_version="0.0.2")) + info: Info = RestField(default=Info(version="0.0.1", api_version="0.0.2")) info2: Info = Info(version="0.0.2", api_version="0.0.3") peoples: dict[str, People] = { "john": People(last_name="Doe"), diff --git a/test/test_rest_resource_walker_tree.py b/test/test_rest_resource_walker_tree.py index afaeb40..fb2d400 100644 --- a/test/test_rest_resource_walker_tree.py +++ b/test/test_rest_resource_walker_tree.py @@ -5,13 +5,13 @@ from typing import Optional from os import chdir from pathlib import Path -from pydantic import Field print(__name__) print(__package__) from src.pyrestresource import ( + RestField, RestResourceBase, ) @@ -80,7 +80,7 @@ def init_classes(): last_name: str class RootApp(RestResourceBase): - info: Info = Field(default=Info(version="0.0.1", api_version="0.0.2")) + info: Info = RestField(default=Info(version="0.0.1", api_version="0.0.2")) info2: Info = Info(version="0.0.2", api_version="0.0.3") peoples: dict[str, People] = { "john": People(last_name="Doe"), diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py index 5c86163..63502f1 100644 --- a/test/test_rest_webserver.py +++ b/test/test_rest_webserver.py @@ -4,7 +4,6 @@ from unittest.mock import patch from os import chdir from pathlib import Path from typing import Optional -from pydantic import Field from uuid import UUID, uuid4 from time import time, sleep import json @@ -19,6 +18,7 @@ print(__name__) print(__package__) from src.pyrestresource import ( + RestField, register_rest_rootpoint, RestResourceBase, rsrc_verb, @@ -40,19 +40,19 @@ def init_classes(): api_version: str class Patch(RestResourceBase): - uuid: UUID = Field(default_factory=uuid4, primary_key=True) + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) shortname: str name: Optional[str] = None description: Optional[str] = None class Profile(RestResourceBase): - uuid: UUID = Field(default_factory=uuid4, primary_key=True) + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) shortname: str name: Optional[str] = None description: Optional[str] = None class Game(RestResourceBase): - uuid: UUID = Field(default_factory=uuid4, primary_key=True) + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) shortname: str name: Optional[str] = None description: Optional[str] = None @@ -63,9 +63,9 @@ def init_classes(): Patch_2 = Patch(uuid="d385a1d2-65fa-11ee-8c99-0242ac120002", shortname="testPatch2") class User(RestResourceBase): - uuid: UUID = Field(default_factory=uuid4, primary_key=True) + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) name: str - secret: str = Field(..., exclude=True) + secret: str = RestField(..., exclude=True) User1 = User( uuid="8da57a3c-661f-11ee-8c99-0242ac120002", @@ -74,7 +74,7 @@ def init_classes(): ) class Patch2(RestResourceBase): - uuid: UUID = Field(default_factory=uuid4, primary_key=True) + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) shortname: str name: Optional[str] = None description: Optional[str] = None -- 2.47.3 From 4af812cf808af65425a444e2ba0960d231c8d1da Mon Sep 17 00:00:00 2001 From: cclecle Date: Mon, 6 Nov 2023 09:27:53 +0000 Subject: [PATCH 14/20] enable multiprocessing on coverage --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 8e85e4c..bf524b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,9 @@ where = ["src"] # module = "pydantic.pydantic_core" # ignore_missing_imports = true +[tool.coverage.run] +concurrency = "multiprocessing" + [project.urls] Homepage = "https://chacha.ddns.net/gitea/chacha/pyrestresource" Documentation = "https://chacha.ddns.net/mkdocs-web/chacha/pyrestresource/master/latest/" -- 2.47.3 From d0e146ac7640f43bd15ad454e8d7f0fc4a38dc5a Mon Sep 17 00:00:00 2001 From: cclecle Date: Mon, 6 Nov 2023 09:33:55 +0000 Subject: [PATCH 15/20] try to fix toml / coverage --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bf524b6..a26d045 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,10 @@ where = ["src"] # ignore_missing_imports = true [tool.coverage.run] -concurrency = "multiprocessing" +concurrency = [ + 'thread', + 'multiprocessing' +] [project.urls] Homepage = "https://chacha.ddns.net/gitea/chacha/pyrestresource" -- 2.47.3 From 0ec875e497e7746d50c66fd98f71f1e6e0a390d1 Mon Sep 17 00:00:00 2001 From: cclecle Date: Mon, 6 Nov 2023 10:35:05 +0000 Subject: [PATCH 16/20] add parallel coverage option --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a26d045..05a9490 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,10 +52,11 @@ where = ["src"] "pyrestresource" = ["py.typed"] # [[tool.mypy.overrides]] -# module = "pydantic.pydantic_core" +# module = "" # ignore_missing_imports = true [tool.coverage.run] +parallel = true concurrency = [ 'thread', 'multiprocessing' -- 2.47.3 From 3afebdba3308778efd4bd531f7b166cba5643140 Mon Sep 17 00:00:00 2001 From: cclecle Date: Mon, 6 Nov 2023 11:25:49 +0000 Subject: [PATCH 17/20] temporary disable PERF tests --- test/test_rest_resource.py | 2 +- test/test_rest_webserver.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_rest_resource.py b/test/test_rest_resource.py index 9773b9b..e169ce0 100644 --- a/test/test_rest_resource.py +++ b/test/test_rest_resource.py @@ -484,7 +484,7 @@ class Test_RestAPI_PERFO(unittest.TestCase): init_classes() self.testapp = RootApp() - # @unittest.skip + @unittest.skip def test_perf_dict(self): print(f"LIB INTERNAL PERF TEST") n_loop = 10000 diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py index 63502f1..ac1ab29 100644 --- a/test/test_rest_webserver.py +++ b/test/test_rest_webserver.py @@ -13,6 +13,7 @@ import requests from contextlib import closing from multiprocessing import Process from requests.adapters import HTTPAdapter +import coverage print(__name__) print(__package__) @@ -121,6 +122,7 @@ def find_free_port(): def launch_server(ip, port): + coverage.process_startup() init_classes() uvicorn.run(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True) @@ -275,7 +277,7 @@ class Test_RestAPI_WebServer(unittest.TestCase): proc.terminate() s.close() - # @unittest.skip + @unittest.skip def test_perf_dict(self): print(f"SOCKET PERF TEST") n_loop = 10000 -- 2.47.3 From 9b3e847908f2804f99674017dfed9930e60eb8da Mon Sep 17 00:00:00 2001 From: cclecle Date: Mon, 6 Nov 2023 13:50:34 +0000 Subject: [PATCH 18/20] use threaded uvicorn during test --- README.md | 16 ++++--- pyproject.toml | 7 ++- test/ThreadedUvicorn.py | 23 ++++++++++ test/__init__.py | 2 + test/test_rest_login.py | 92 +++++++++++++------------------------ test/test_rest_webserver.py | 36 +++++---------- 6 files changed, 84 insertions(+), 92 deletions(-) create mode 100644 test/ThreadedUvicorn.py diff --git a/README.md b/README.md index 36ee062..192803e 100644 --- a/README.md +++ b/README.md @@ -12,21 +12,23 @@ A RESTful API library built on top of pydantic & uvicorn to make service API from a data model. -/!\ early in-progress project for internal use ATM. +/!\\ early in-progress project for internal use ATM. Feel free to contribute. -Features: -- use annotation +Features (available): +- type annotation used - support containers (dict) - support plugins (for hook and biding) -- user authentification (WIP) -- ACL (WIP) -- python internal model instance (with possible serialization/auto-save on-disk) +- user auth +- ACL - daemon mode +Features(planned): +- group support +- python internal model instance (with possible serialization/auto-save on-disk) + Limitations: - no nested reads / writes -- weak unitest (atm) Checkout [Latest Documentation](https://chacha.ddns.net/mkdocs-web/chacha/pyrestresource/master/latest/). \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 05a9490..8866cff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,10 +56,13 @@ where = ["src"] # ignore_missing_imports = true [tool.coverage.run] +cover_pylib = false +branch = true +data_file="helpers-results/cl_unit_test_raw_coverage/.coverage" +# debug = ["config","multiproc","process"] parallel = true concurrency = [ - 'thread', - 'multiprocessing' + 'thread' ] [project.urls] diff --git a/test/ThreadedUvicorn.py b/test/ThreadedUvicorn.py new file mode 100644 index 0000000..b47d625 --- /dev/null +++ b/test/ThreadedUvicorn.py @@ -0,0 +1,23 @@ +from uvicorn import Config, Server +from threading import Thread +import asyncio + + +class ThreadedUvicorn: + def __init__(self, config: Config): + self.server = Server(config) + self.thread = Thread(daemon=True, target=self.server.run) + + def start(self): + self.thread.start() + asyncio.run(self.wait_for_started()) + + async def wait_for_started(self): + while not self.server.started: + await asyncio.sleep(0.1) + + def stop(self): + if self.thread.is_alive(): + self.server.should_exit = True + while self.thread.is_alive(): + continue diff --git a/test/__init__.py b/test/__init__.py index 0aef653..006fd7e 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -5,3 +5,5 @@ # # You should have received a copy of the license along with this # work. If not, see . + +from .ThreadedUvicorn import ThreadedUvicorn diff --git a/test/test_rest_login.py b/test/test_rest_login.py index ee369b6..8639cfc 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -32,6 +32,8 @@ from src.pyrestresource import ( ) +from test import ThreadedUvicorn + testdir_path = Path(__file__).parent.resolve() chdir(testdir_path.parent.resolve()) @@ -93,26 +95,18 @@ def find_free_port(): return "localhost", s.getsockname()[1] -def launch_server(ip, port): - init_classes() - uvicorn.run(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True) - - class Test_RestAPI_LOGIN_Web(unittest.TestCase): def setUp(self) -> None: chdir(testdir_path.parent.resolve()) def test_login_two_users(self): ip, port = find_free_port() - proc = Process( - target=launch_server, - args=( - ip, - port, - ), - ) - proc.start() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() sleep(1) + s = requests.Session() s.mount("http://", HTTPAdapter(max_retries=0)) @@ -206,19 +200,15 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): self.assertEqual(response.json(), "A TEST SET VALUE 2") finally: - proc.terminate() s.close() + server.stop() def test_login(self): ip, port = find_free_port() - proc = Process( - target=launch_server, - args=( - ip, - port, - ), - ) - proc.start() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() sleep(1) s = requests.Session() s.mount("http://", HTTPAdapter(max_retries=0)) @@ -260,19 +250,15 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): self.assertEqual(response.json(), "TestUser") finally: - proc.terminate() s.close() + server.stop() def test_change_host(self): ip, port = find_free_port() - proc = Process( - target=launch_server, - args=( - ip, - port, - ), - ) - proc.start() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() sleep(1) s1 = requests.Session() s1.mount("http://", HTTPAdapter(max_retries=0)) @@ -378,20 +364,16 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): self.assertEqual(response.json(), "__ANNONYMOUS__") finally: - proc.terminate() s1.close() s2.close() + server.stop() def test_login_wrong_pwd(self): ip, port = find_free_port() - proc = Process( - target=launch_server, - args=( - ip, - port, - ), - ) - proc.start() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() sleep(1) s = requests.Session() s.mount("http://", HTTPAdapter(max_retries=0)) @@ -493,19 +475,15 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): self.assertDictEqual(s.cookies.get_dict(), {}) finally: - proc.terminate() s.close() + server.stop() def test_access_resourceACL(self): ip, port = find_free_port() - proc = Process( - target=launch_server, - args=( - ip, - port, - ), - ) - proc.start() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() sleep(1) s = requests.Session() s.mount("http://", HTTPAdapter(max_retries=0)) @@ -570,19 +548,15 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): self.assertEqual(response.json(), "TEST SET VALUE 2") finally: - proc.terminate() s.close() + server.stop() def test_access_fieldACL(self): ip, port = find_free_port() - proc = Process( - target=launch_server, - args=( - ip, - port, - ), - ) - proc.start() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() sleep(1) s = requests.Session() s.mount("http://", HTTPAdapter(max_retries=0)) @@ -647,5 +621,5 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): self.assertEqual(response.json(), "TEST SET VALUE 2") finally: - proc.terminate() s.close() + server.stop() diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py index ac1ab29..dbe2af7 100644 --- a/test/test_rest_webserver.py +++ b/test/test_rest_webserver.py @@ -30,6 +30,8 @@ from src.pyrestresource import ( ) from pprint import pprint +from test import ThreadedUvicorn + testdir_path = Path(__file__).parent.resolve() chdir(testdir_path.parent.resolve()) @@ -121,26 +123,16 @@ def find_free_port(): return "localhost", s.getsockname()[1] -def launch_server(ip, port): - coverage.process_startup() - init_classes() - uvicorn.run(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True) - - class Test_RestAPI_WebServer(unittest.TestCase): def setUp(self) -> None: chdir(testdir_path.parent.resolve()) def test_nomal_AllCmd_games(self): ip, port = find_free_port() - proc = Process( - target=launch_server, - args=( - ip, - port, - ), - ) - proc.start() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() sleep(1) s = requests.Session() s.mount("http://", HTTPAdapter(max_retries=0)) @@ -274,8 +266,8 @@ class Test_RestAPI_WebServer(unittest.TestCase): data = response.json() self.assertTrue(len(data) == 0) finally: - proc.terminate() s.close() + server.stop() @unittest.skip def test_perf_dict(self): @@ -283,14 +275,10 @@ class Test_RestAPI_WebServer(unittest.TestCase): n_loop = 10000 ip, port = find_free_port() - proc = Process( - target=launch_server, - args=( - ip, - port, - ), - ) - proc.start() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() sleep(1) s = requests.Session() s.mount("http://", HTTPAdapter(max_retries=0)) @@ -368,5 +356,5 @@ class Test_RestAPI_WebServer(unittest.TestCase): print(f"PUT/GET 2nd level (value) dict: {int(n_loop/(end-start))} Req/s") finally: - proc.terminate() s.close() + server.stop() -- 2.47.3 From 7e13d49febf0ca67b0fd427c9d95c7430c455485 Mon Sep 17 00:00:00 2001 From: cclecle Date: Mon, 6 Nov 2023 13:57:00 +0000 Subject: [PATCH 19/20] remove unused data dir from toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8866cff..10abb4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ include-package-data = true where = ["src"] [tool.setuptools.package-data] -"pyrestresource.data" = ["*.*"] "pyrestresource" = ["py.typed"] # [[tool.mypy.overrides]] -- 2.47.3 From 6311d90a2d17ccd5523eaff93df31501afc5d761 Mon Sep 17 00:00:00 2001 From: cclecle Date: Mon, 6 Nov 2023 14:56:46 +0000 Subject: [PATCH 20/20] update jenkins & toml from project template --- Jenkinsfile | 14 +++++++++++--- pyproject.toml | 12 ++++++------ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index ad653b3..eb6cf6c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -426,9 +426,17 @@ pipeline { } post { always { - dir("gitrepo") { - publishCoverage adapters: [cobertura(mergeToOneReport: true, path: "helpers-results/cl_types_check/cobertura.xml")] - junit 'helpers-results/cl_types_check/junit.xml' + dir("gitrepo") { + //publish coverage + recordCoverage( sourceDirectories: [[path: 'src']], + tools: [[parser: 'COBERTURA', pattern: 'helpers-results/cl_types_check/cobertura.xml']], + id: 'COBERTURA', name: 'COBERTURA Coverage', + sourceCodeRetention: 'EVERY_BUILD',) + + //add type check to junit result set + junit 'helpers-results/cl_types_check/junit.xml' + + //publish html reports files publishHTML([ reportDir: "helpers-results/cl_quality_check", reportFiles: "report.html", diff --git a/pyproject.toml b/pyproject.toml index 10abb4a..22b81f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,12 +70,12 @@ Documentation = "https://chacha.ddns.net/mkdocs-web/chacha/pyrestresource/mast Tracker = "https://chacha.ddns.net/gitea/chacha/pyrestresource/issues" [project.optional-dependencies] -test = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] -coverage-check = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] -complexity-check = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] -quality-check = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] -type-check = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] -doc-gen = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] +test = ["chacha_cicd_helper"] +coverage-check = ["chacha_cicd_helper"] +complexity-check = ["chacha_cicd_helper"] +quality-check = ["chacha_cicd_helper"] +type-check = ["chacha_cicd_helper"] +doc-gen = ["chacha_cicd_helper"] # [project.scripts] # my-script = "my_package.module:function" -- 2.47.3