diff --git a/.project b/.project index dc36649..f4b48b7 100644 --- a/.project +++ b/.project @@ -1,6 +1,6 @@ - {{project_name}} + pyrestresource diff --git a/.settings/org.eclipse.core.resources.prefs b/.settings/org.eclipse.core.resources.prefs index 99f26c0..c89f8d0 100644 --- a/.settings/org.eclipse.core.resources.prefs +++ b/.settings/org.eclipse.core.resources.prefs @@ -1,2 +1,7 @@ eclipse.preferences.version=1 +encoding//src/pyrestresource/__init__.py=utf-8 +encoding//src/pyrestresource/__metadata__.py=utf-8 +encoding//src/pyrestresource/rest_login.py=utf-8 +encoding//src/pyrestresource/rest_resource.py=utf-8 +encoding//src/pyrestresource/rest_resource_handler_walker.py=utf-8 encoding/=UTF-8 diff --git a/Dockerfile b/Dockerfile index 73b5bb6..9c2c266 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,13 +6,9 @@ # You should have received a copy of the license along with this # work. If not, see . -FROM debian:bullseye-slim +FROM debian:bookworm-slim ENV DEBIAN_FRONTEND=noninteractive RUN apt update -RUN apt install -y python3.9 python3-virtualenv python3-pip git python3.9-venv weasyprint - -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --upgrade virtualenv -RUN python3 -m pip install --upgrade setuptools wheel build \ No newline at end of file +RUN apt install -y python3.11 python3-virtualenv python3-pip git python3-venv weasyprint \ No newline at end of file diff --git a/Jenkinsfile b/Jenkinsfile index ad653b3..eb6cf6c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -426,9 +426,17 @@ pipeline { } post { always { - dir("gitrepo") { - publishCoverage adapters: [cobertura(mergeToOneReport: true, path: "helpers-results/cl_types_check/cobertura.xml")] - junit 'helpers-results/cl_types_check/junit.xml' + dir("gitrepo") { + //publish coverage + recordCoverage( sourceDirectories: [[path: 'src']], + tools: [[parser: 'COBERTURA', pattern: 'helpers-results/cl_types_check/cobertura.xml']], + id: 'COBERTURA', name: 'COBERTURA Coverage', + sourceCodeRetention: 'EVERY_BUILD',) + + //add type check to junit result set + junit 'helpers-results/cl_types_check/junit.xml' + + //publish html reports files publishHTML([ reportDir: "helpers-results/cl_quality_check", reportFiles: "report.html", diff --git a/README.md b/README.md index 82c6bd2..192803e 100644 --- a/README.md +++ b/README.md @@ -8,46 +8,27 @@ ![](docs-static/Library.jpg) -# Python project template +# pyrestresource -A nice template to start blank python projets. - -This template automate a lot of handy things and allow CI/CD automatic releases generation. +A RESTful API library built on top of pydantic & uvicorn to make service API from a data model. -It is also collectings data to feed Jenkins build. +/!\\ early in-progress project for internal use ATM. -Checkout [Latest Documentation](https://chacha.ddns.net/mkdocs-web/chacha/{{repository}}/{{branch}}/latest/). +Feel free to contribute. -## Features +Features (available): +- type annotation used +- support containers (dict) +- support plugins (for hook and biding) +- user auth +- ACL +- daemon mode -### Generic pipeline skeleton: - - Prepare - - GetCode - - BuildPackage - - Install - - CheckCode - - PlotMetrics - - RunUnitTests - - GenDOC - - PostRelease - -### CI/CD Environment - - Jenkins - - Gitea (with patch for dynamic Readme variables: https://chacha.ddns.net/gitea/chacha/GiteaMarkupVariable) - - Docker - - MkDocsWeb - -### CI/CD Helper libs - - VirtualEnv - - Changelog generation based on commits - - copier - - pylint + pylint_json2html - - mypy - - unittest + xmlrunner + junitparser + junit2htmlreport - - mkdocs +Features(planned): +- group support +- python internal model instance (with possible serialization/auto-save on-disk) -### Python project - - Full .toml implementation - - .whl automatic generation - - dynamic versionning using git repository - - embedded unit-test \ No newline at end of file +Limitations: +- no nested reads / writes + +Checkout [Latest Documentation](https://chacha.ddns.net/mkdocs-web/chacha/pyrestresource/master/latest/). \ No newline at end of file diff --git a/RUN_quality.launch b/RUN_quality.launch index 079d5ed..95fbee2 100644 --- a/RUN_quality.launch +++ b/RUN_quality.launch @@ -10,7 +10,7 @@ - + diff --git a/RUN_types.launch b/RUN_types.launch new file mode 100644 index 0000000..747af5d --- /dev/null +++ b/RUN_types.launch @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/pyproject.toml b/pyproject.toml index cfea0e9..22b81f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ version_scheme= "post-release" name = "pyrestresource" description = "pyrestresource" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.11" keywords = ["chacha","chacha","template","pyrestresource"] license = { file = "LICENSE.md" } @@ -30,11 +30,13 @@ maintainers = [ classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.11", ] dependencies = [ - 'importlib-metadata; python_version<"3.9"', - 'packaging' + 'packaging', + 'typeguard', + 'pydantic>=2.4,<3', + 'uvicorn>=0.23' ] dynamic = ["version"] @@ -46,21 +48,34 @@ include-package-data = true where = ["src"] [tool.setuptools.package-data] -"pyrestresource.data" = ["*.*"] "pyrestresource" = ["py.typed"] +# [[tool.mypy.overrides]] +# module = "" +# ignore_missing_imports = true + +[tool.coverage.run] +cover_pylib = false +branch = true +data_file="helpers-results/cl_unit_test_raw_coverage/.coverage" +# debug = ["config","multiproc","process"] +parallel = true +concurrency = [ + 'thread' +] + [project.urls] Homepage = "https://chacha.ddns.net/gitea/chacha/pyrestresource" Documentation = "https://chacha.ddns.net/mkdocs-web/chacha/pyrestresource/master/latest/" Tracker = "https://chacha.ddns.net/gitea/chacha/pyrestresource/issues" [project.optional-dependencies] -test = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] -coverage-check = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] -complexity-check = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] -quality-check = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] -type-check = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] -doc-gen = ["chacha_cicd_helper@git+https://chacha.ddns.net/gitea/chacha/chacha_cicd_helper.git@master"] +test = ["chacha_cicd_helper"] +coverage-check = ["chacha_cicd_helper"] +complexity-check = ["chacha_cicd_helper"] +quality-check = ["chacha_cicd_helper"] +type-check = ["chacha_cicd_helper"] +doc-gen = ["chacha_cicd_helper"] # [project.scripts] # my-script = "my_package.module:function" diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index 7805b58..6e63320 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + # pyrestresource (c) by chacha # # pyrestresource is licensed under a @@ -5,32 +8,58 @@ # # You should have received a copy of the license along with this # work. If not, see . +# pylint: disable=wrong-import-position """ Main module __init__ file. """ -from importlib.metadata import version, distribution, PackageNotFoundError -import warnings +from typing import TYPE_CHECKING -from .test_module import test_function +from .__metadata__ import __version__, __Summuary__, __Name__ -try: # pragma: no cover - __version__ = version("pyrestresource") -except PackageNotFoundError: # pragma: no cover - warnings.warn("can not read __version__, assuming local test context, setting it to ?.?.?") - __version__ = "?.?.?" +from .rest_model import RestField +from .rest_resource_rootpoint import register_rest_rootpoint +from .rest_types import rsrc_verb, T_SupportedRESTFields +from .rest_request_opt import ( + RestRequestParams_POST, + RestRequestParams_DELETE, + RestRequestParams_GET, + RestRequestParams_PUT, + RestRequestParams_RestResourceBase_PUT, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, +) +from .rest_resource_plugin import ( + ResourcePlugin_field_default, + ResourcePlugin_RestResourceBase_default, + ResourcePlugin_dict_default, +) +from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule +from .rest_resource import RestResourceBase +from .rest_login import ( + RestResourceBaseLogin, + UserLogin, +) +from .rest_exceptions import ( + RestResourceException, + RestResourceLoginException, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_ClientChange, + RestResourceLoginException_InvalidSession, + RestResourcePluginException, + RestResourcePluginException_InvalidPluginSignature, + RestResourceHandlerException_Forbiden, +) -try: # pragma: no cover - dist = distribution("pyrestresource") - __Summuary__ = dist.metadata["Summary"] -except PackageNotFoundError: # pragma: no cover - warnings.warn('can not read dist.metadata["Summary"], assuming local test context, setting it to ') - __Summuary__ = "pyrestresource description" - -try: # pragma: no cover - dist = distribution("pyrestresource") - __Name__ = dist.metadata["Name"] -except PackageNotFoundError: # pragma: no cover - warnings.warn('can not read dist.metadata["Name"], assuming local test context, setting it to ') - __Name__ = "pyrestresource" +if TYPE_CHECKING: + from .rest_types import ( + T_ListIndex, + T_ListSize, + T_DictKey, + T_T_DictKey, + T_DictValues, + T_T_DictValues, + ) diff --git a/src/pyrestresource/__metadata__.py b/src/pyrestresource/__metadata__.py new file mode 100644 index 0000000..cfb1763 --- /dev/null +++ b/src/pyrestresource/__metadata__.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# pyrestresource(c) by chacha +# +# pyrestresource is licensed under a +# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. +# +# You should have received a copy of the license along with this +# work. If not, see . + +"""Metadata module module""" + +from importlib.metadata import version, distribution, PackageNotFoundError +import warnings + + +try: # pragma: no cover + __version__ = version("pyrestresource") +except PackageNotFoundError: # pragma: no cover + warnings.warn( + "can not read __version__, assuming local test context, setting it to ?.?.?" + ) + __version__ = "?.?.?" + +try: # pragma: no cover + dist = distribution("pyrestresource") + __Summuary__ = dist.metadata["Summary"] +except PackageNotFoundError: # pragma: no cover + warnings.warn( + 'can not read dist.metadata["Summary"], assuming local test context, setting it to ' + ) + __Summuary__ = "pyrestresource description" + +try: # pragma: no cover + dist = distribution("pyrestresource") + __Name__ = dist.metadata["Name"] +except PackageNotFoundError: # pragma: no cover + warnings.warn( + 'can not read dist.metadata["Name"], assuming local test context, setting it to ' + ) + __Name__ = "pyrestresource" diff --git a/src/pyrestresource/data/.keep b/src/pyrestresource/data/.keep deleted file mode 100644 index e69de29..0000000 diff --git a/src/pyrestresource/data/__init__.py b/src/pyrestresource/data/__init__.py deleted file mode 100644 index 0aef653..0000000 --- a/src/pyrestresource/data/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# pyrestresource (c) by chacha -# -# pyrestresource is licensed under a -# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. -# -# You should have received a copy of the license along with this -# work. If not, see . diff --git a/src/pyrestresource/helpers.py b/src/pyrestresource/helpers.py new file mode 100644 index 0000000..8c569ea --- /dev/null +++ b/src/pyrestresource/helpers.py @@ -0,0 +1,40 @@ +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring + +from __future__ import annotations + +from uuid import UUID +import json +import traceback + +from .rest_types import T_Gen_DictKeys + + +class _JSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, T_Gen_DictKeys): # pylint: disable=isinstance-second-argument-not-valid-type + return list(o) + if isinstance(o, UUID): + # if the obj is uuid, we simply return the value of uuid + return str(o) + return json.JSONEncoder.default(self, o) + + +def parse_dict_cookies(cookies: str) -> dict[str, str | None]: + result: dict[str, str | None] = {} + for item in cookies.split(";"): + item = item.strip() + if not item: + continue + if "=" not in item: + result[item] = None + continue + name, value = item.split("=", 1) + result[name] = value + return result + + +def forward_exception(e: Exception, forward: bool) -> None: + if forward: + raise e from None + else: + traceback.print_exc() diff --git a/src/pyrestresource/rest_ACL.py b/src/pyrestresource/rest_ACL.py new file mode 100644 index 0000000..bde356d --- /dev/null +++ b/src/pyrestresource/rest_ACL.py @@ -0,0 +1,45 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from pydantic import BaseModel +from enum import Enum, auto + +from .rest_types import rsrc_verb + +if TYPE_CHECKING is True: + from .rest_login import UserLogin + + +class ACL_target(BaseModel): + pass + + +class ACL_target_user(ACL_target): + name: str + + @classmethod + def from_user_login(cls, user_login: UserLogin) -> ACL_target_user: + return cls(name=user_login.username) + + +class ACL_target_user_Annonymous(ACL_target_user): + name: str = "__ANNONYMOUS__" + + +class ACL_target_group(ACL_target): + name: str + + +class ACL_target_group_Any(ACL_target_group): + name: str = "__ANY__" + + +class ACL_rule(Enum): + ALLOW = auto() + DENY = auto() + + +class ACL_record(BaseModel): + verbs: list[rsrc_verb] + target: ACL_target + rule: ACL_rule diff --git a/src/pyrestresource/rest_exceptions.py b/src/pyrestresource/rest_exceptions.py new file mode 100644 index 0000000..64d582c --- /dev/null +++ b/src/pyrestresource/rest_exceptions.py @@ -0,0 +1,62 @@ +class RestResourceException(Exception): + pass + + +class RestResourceConfigException(RestResourceException): + pass + + +class RestResourceModelException(RestResourceException): + pass + + +class RestResourceModelException_ACL(RestResourceModelException): + pass + + +class RestResourceHandlerException(RestResourceException): + pass + + +class RestResourceHandlerException_ResourceNotFound(RestResourceHandlerException): + pass + + +class RestResourceHandlerException_MethodNotAllowed(RestResourceHandlerException): + pass + + +class RestResourceHandlerException_BadRequest(RestResourceHandlerException): + pass + + +class RestResourceHandlerException_Forbiden(RestResourceHandlerException): + pass + + +class RestResourceLoginException(RestResourceException): + pass + + +class RestResourceLoginException_SessionTimeout(RestResourceLoginException): + pass + + +class RestResourceLoginException_ClientChange(RestResourceLoginException): + pass + + +class RestResourceLoginException_InvalidSession(RestResourceLoginException): + pass + + +class RestResourceLoginException_InvalidCredentials(RestResourceLoginException): + pass + + +class RestResourcePluginException(RestResourceException): + pass + + +class RestResourcePluginException_InvalidPluginSignature(RestResourcePluginException): + pass diff --git a/src/pyrestresource/rest_login.py b/src/pyrestresource/rest_login.py new file mode 100644 index 0000000..4857d7f --- /dev/null +++ b/src/pyrestresource/rest_login.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# pyrestresource(c) by chacha +# +# pyrestresource is licensed under a +# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. +# +# You should have received a copy of the license along with this +# work. If not, see . + +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring + +"""CLI interface module""" +from __future__ import annotations +from typing import Optional, ClassVar, TYPE_CHECKING + +from secrets import token_hex, compare_digest +from datetime import datetime, timedelta +from pydantic import BaseModel + +from .rest_types import rsrc_verb +from .rest_resource import RestResourceBase +from .rest_model import RestField +from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule, ACL_target_user +from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default +from .rest_exceptions import ( + RestResourceLoginException_InvalidCredentials, + RestResourceLoginException_ClientChange, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_InvalidSession, +) + +if TYPE_CHECKING is True: + from .rest_request import RestRequest + from .rest_request_opt import RestRequestParams_RestResourceBase_PUT, RestRequestParams_RestResourceBase_GET + + +class UserLogin(BaseModel): + username: str + secret: str + + +class UserSession(BaseModel): + last_update: datetime + user_login: UserLogin + client: tuple[str, int] | tuple[()] | None + + +class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): + ar_UserLogin: list[UserLogin] = [] + + def handle_resource_get(self, resource: Login, params: RestRequestParams_RestResourceBase_GET) -> Login: + return Login(username=self.get_user_login(), secret=None) + + def handle_resource_put(self, resource: Login, params: RestRequestParams_RestResourceBase_PUT) -> Login: + if resource.username is None or resource.secret is None: + raise RestResourceLoginException_InvalidCredentials() + token = self.user_login(resource.username, resource.secret) + self.set_resp_cookie_value("Authorization", f"Bearer {token}") + return resource + + +class Login(RestResourceBase): + username: Optional[str] = RestField(None) + secret: Optional[str] = RestField( + None, + exclude=True, + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + + +class RestResourceBaseLogin(RestResourceBase): + _ar_user_login: ClassVar[list[UserLogin]] = [] + _ar_user_session: dict[str, UserSession] = {} + _max_session_inactive: ClassVar[timedelta] = timedelta(minutes=20) + _max_session_time: ClassVar[timedelta] = timedelta(hours=12) + login: Login = RestField(default=Login(), plugin=ResourcePlugin_Login) + + def get_new_cookie_expiration_date(self) -> datetime: + return datetime.now() + self._max_session_time + + def _process_request_session(self, request: RestRequest) -> None: + # print(f"[TRACE] {type(self).__name__}->_process_request_session()") + # print(f"[TRACE] request: {id(request)}") + auth_cookie = request.get_cookie("Authorization") + if auth_cookie != None: + if auth_cookie in self._ar_user_session: + # print(f"SESSION FOUND for {request.get_client()}") + + if self._ar_user_session[auth_cookie].client != request.get_client(): + del self._ar_user_session[auth_cookie] + raise RestResourceLoginException_ClientChange() + + time_diff = datetime.now() - self._ar_user_session[auth_cookie].last_update + if time_diff > self._max_session_inactive: + del self._ar_user_session[auth_cookie] + raise RestResourceLoginException_SessionTimeout() + + request.set_user(ACL_target_user(name=self._ar_user_session[auth_cookie].user_login.username)) + # print("SESSION RECOVERED") + return + + raise RestResourceLoginException_InvalidSession() + return + + # print(f"non-connected user {request.get_client()}") + + def user_login(self, user_name: str, user_secret: str, request: RestRequest) -> str: + already_failed: bool = False + + for iter_user_login in self._ar_user_login: + username_ok: bool = compare_digest(user_name, iter_user_login.username) + secret_ok: bool = compare_digest(user_secret, iter_user_login.secret) + + if username_ok is True: + if secret_ok is True and not already_failed: + return self._register_user_session(iter_user_login, request) + else: + already_failed = True + else: + pass + pass + + raise RestResourceLoginException_InvalidCredentials() + + def _register_user_session(self, user_login: UserLogin, request: RestRequest) -> str: + token = token_hex(16) + new_user_session = UserSession(last_update=datetime.now(), user_login=user_login, client=request.get_client()) + self._ar_user_session[f"Bearer {token}"] = new_user_session + return token diff --git a/src/pyrestresource/rest_model.py b/src/pyrestresource/rest_model.py new file mode 100644 index 0000000..3f5051b --- /dev/null +++ b/src/pyrestresource/rest_model.py @@ -0,0 +1,103 @@ +from __future__ import annotations +from typing import ( + Any, + Literal, + Callable, + Optional, + TYPE_CHECKING, +) + +from pydantic.fields import Field, _Unset, PydanticUndefined + +from .rest_exceptions import RestResourceModelException + +if TYPE_CHECKING is True: + from .rest_ACL import ACL_record + from .rest_resource_plugin import ResourcePlugin + from typing import Unpack + from pydantic.fields import _EmptyKwargs, AliasPath, AliasChoices + + +def RestField( + default: Any = PydanticUndefined, + *, + default_factory: Callable[[], Any] | None = _Unset, + alias: str | None = _Unset, + alias_priority: int | None = _Unset, + validation_alias: str | AliasPath | AliasChoices | None = _Unset, + serialization_alias: str | None = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + examples: list[Any] | None = _Unset, + exclude: bool | None = _Unset, + discriminator: str | None = _Unset, + json_schema_extra: dict[str, Any] | Callable[[dict[str, Any]], None] | None = _Unset, + frozen: bool | None = _Unset, + validate_default: bool | None = _Unset, + repr: bool = _Unset, + init_var: bool | None = _Unset, + kw_only: bool | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + union_mode: Literal["smart", "left_to_right"] = _Unset, + ACL: Optional[list[ACL_record]] = _Unset, + plugin: Optional[type[ResourcePlugin]] = _Unset, + **extra: Unpack[_EmptyKwargs], +) -> Any: + if not json_schema_extra or json_schema_extra is _Unset: + if extra: + json_schema_extra = extra # type: ignore + else: + json_schema_extra = {} + + if ACL is not _Unset: + json_schema_extra["ACL"] = ACL + + if plugin is not _Unset: + json_schema_extra["plugin"] = plugin + else: + raise RestResourceModelException("json_schema_extra must not be set") + + return Field( + default, + default_factory=default_factory, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + examples=examples, + exclude=exclude, + discriminator=discriminator, + json_schema_extra=json_schema_extra, + frozen=frozen, + validate_default=validate_default, + repr=repr, + init_var=init_var, + kw_only=kw_only, + pattern=pattern, + strict=strict, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + min_length=min_length, + max_length=max_length, + union_mode=union_mode, + **extra, + ) diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py new file mode 100644 index 0000000..a774525 --- /dev/null +++ b/src/pyrestresource/rest_request.py @@ -0,0 +1,296 @@ +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring +"""A module to handle http requests context""" + +from __future__ import annotations +from typing import ( + Any, + Generic, + TYPE_CHECKING, +) + +from re import sub +from urllib.parse import urlparse, parse_qs +from pydantic import BaseModel, Field +from typeguard import check_type + +from .rest_login import RestResourceBaseLogin +from .rest_types import rsrc_verb, T_AllSupportedFields +from .rest_request_opt import ( + RestRequestParams_POST, + RestRequestParams_DELETE, + RestRequestParams_GET, + RestRequestParams_PUT, + _T_RestRequestParams, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, +) +from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group +from .helpers import parse_dict_cookies +from .rest_exceptions import ( + RestResourceHandlerException_MethodNotAllowed, + RestResourceHandlerException_BadRequest, + RestResourceException, + RestResourceConfigException, +) + +if TYPE_CHECKING is True: + from typing import Optional + from .rest_types import T_SupportedRESTFields + from .rest_resource import RestResourceBase + + +class RequestFactory( + Generic[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ], + BaseModel, +): + """RestRequets class factory""" + + cls_RestRequestParams_GET: type[RestRequestParams_GET] = Field(default=RestRequestParams_GET) + cls_RestRequestParams_PUT: type[RestRequestParams_PUT] = Field(default=RestRequestParams_PUT) + cls_RestRequestParams_POST: type[RestRequestParams_POST] = Field(default=RestRequestParams_POST) + cls_RestRequestParams_DELETE: type[RestRequestParams_DELETE] = Field(default=RestRequestParams_DELETE) + + def get_RestRequest( + self, root_resource: RestResourceBase, url: str, verb: rsrc_verb, data: dict, query_string: Optional[str] = None + ) -> RestRequest: + """get a RestRequets instance based on LUT_verb configuration + + Args: + url: http url of the request + verb: http verb received + data: data associated with the request + """ + + # /!\ mypy seems not being able to propagate typevar to composed classes + if verb is rsrc_verb.GET: + return RestRequest[RestRequestParams_GET](self.cls_RestRequestParams_GET, root_resource, url, verb, data, query_string) + if verb is rsrc_verb.PUT: + return RestRequest[RestRequestParams_PUT](self.cls_RestRequestParams_PUT, root_resource, url, verb, data, query_string) + if verb is rsrc_verb.POST: + return RestRequest[RestRequestParams_POST](self.cls_RestRequestParams_POST, root_resource, url, verb, data, query_string) + if verb is rsrc_verb.DELETE: + return RestRequest[RestRequestParams_DELETE](self.cls_RestRequestParams_DELETE, root_resource, url, verb, data, query_string) + raise RestResourceHandlerException_MethodNotAllowed("Invalid Verb") + + def update_RestRequest(self, request: RestRequest) -> None: + """create an updated copy of a RestRequest object based on a different LUT_verb configuration + Args: + origin_request: the original request + """ + + # /!\ mypy seems not being able to propagate typevar to composed classes + if request.verb is rsrc_verb.GET: + request.update_ReqParams(self.cls_RestRequestParams_GET) + elif request.verb is rsrc_verb.PUT: + request.update_ReqParams(self.cls_RestRequestParams_PUT) + elif request.verb is rsrc_verb.POST: + request.update_ReqParams(self.cls_RestRequestParams_POST) + elif request.verb is rsrc_verb.DELETE: + request.update_ReqParams(self.cls_RestRequestParams_DELETE) + else: + raise RestResourceHandlerException_MethodNotAllowed("Invalid Verb") + return + + +class RestRequest(Generic[_T_RestRequestParams]): + # pylint: disable=too-many-instance-attributes + """Main RestRequets class""" + + def __init__( + self, + type_request_params: type[_T_RestRequestParams], + root_resource: RestResourceBase, + url: str, + verb: rsrc_verb, + data: Optional[dict[str, T_SupportedRESTFields]] = None, + query_string: Optional[str] = None, + ) -> None: + """class to handle a request context, that will be kept and updated while walking url parts + + Args: + type_request_params: type of the request param + url: http url of the request + verb: http verb received + data: data associated with the request + In this case, all other argument - but type_request_params - are ignored and inherited from the origin_request + query_string: query arguments after url (eg: ?arg1=value1&arg2=value2 ...) + """ + + # defining all types + self.url: str + self.verb: rsrc_verb + self.data: dict + self._raw_headers: list[Any] = [] + self._client: tuple[str, int] | tuple[()] = () + self.headers: dict[str, None | str | dict[str, None | str]] = {"host": None, "cookie": {}} + self._saved_url_params: dict + self.ReqParams: _T_RestRequestParams = type_request_params() + self.url_stack: list[str] + self._saved_url_stack: list[str] + self.url_stack_index: int + self.outgoing_cookie: dict[str, str] = {} + self.user: ACL_target_user = ACL_target_user_Annonymous() + self.groups: list[ACL_target_group] = [] + self.result: Optional[str] = None + self._forced_status: Optional[int] = None + self.root_resource: RestResourceBase = root_resource + + # = or create a fresh one = + if url is None or verb is None or data is None: + raise RestResourceException("url and verb and data must be set") + self.url = url + self.verb = verb + + if data != {} and not check_type(data, T_AllSupportedFields): + raise RestResourceHandlerException_BadRequest(f"Wrong data type received: {data}") + + self.data = data + + # parse_qs returns list[] for all keys, the command convert list to single items so pydantic can eat them :) + if query_string: + self._saved_url_params = dict((k, v if len(v) > 1 else v[0]) for k, v in parse_qs(query_string).items()) + else: + self._saved_url_params = dict((k, v if len(v) > 1 else v[0]) for k, v in parse_qs(urlparse(url).query).items()) + + if type_request_params: + self.ReqParams = type_request_params(**self._saved_url_params) # actual lunch + + self._parse_url(url) + + # keeping a backup of the original url stack + self._saved_url_stack = self.url_stack.copy() + self.url_stack_index = 0 + + def set_resp_status(self, status: int) -> None: + self._forced_status = status + + def get_root_resource(self) -> RestResourceBase: + return self.root_resource + + def get_status(self) -> int: + if self._forced_status is not None: + return self._forced_status + + if self.verb in (rsrc_verb.POST, rsrc_verb.PUT): + return 201 + + return 200 + + def set_client(self, client: tuple[str, int]) -> None: + self._client = client + + def get_client(self) -> tuple[str, int] | tuple[()]: + return self._client + + def set_headers(self, headers: list[Any]) -> None: + self._raw_headers = headers + for elem in self._raw_headers: + if elem[0] == b"host": + self.headers["host"] = elem[1].decode("utf-8") + elif elem[0] == b"user-agent": + self.headers["user-agent"] = elem[1].decode("utf-8") + elif elem[0] == b"cookie": + self.headers["cookie"] = parse_dict_cookies(elem[1].decode("utf-8")) + + def get_cookie(self, key: str) -> str | None: + if self.headers["cookie"] is None: + return None + if key not in self.headers["cookie"]: + return None + if isinstance(self.headers["cookie"], dict): + return self.headers["cookie"][key] + else: + return None + + def set_resp_cookie_value(self, key: str, value: str) -> None: + if not isinstance(self.root_resource, RestResourceBaseLogin): + raise RestResourceConfigException("root_resource must be RestResourceBaseLogin to use user_login") + self.outgoing_cookie[ + key + ] = f"{value}; expires={self.root_resource.get_new_cookie_expiration_date().strftime('%a, %d %b %Y %H:%M:%S GMT')}; path=/; HttpOnly" + + def reset_resp_cookie(self, key: str) -> None: + self.outgoing_cookie[key] = "null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT" + + def get_host(self) -> str | dict[str, str | None] | None: + return self.headers["host"] + + def set_result(self, result: str): + self.result = result + + def get_result(self) -> Optional[str]: + return self.result + + def set_user(self, user: ACL_target_user): + self.user = user + + def get_user(self): + return self.user + + def add_group(self, group: ACL_target_group): + self.groups.append(group) + + def update_ReqParams(self, type_request_params: type[_T_RestRequestParams]): + self.ReqParams = type_request_params(**self._saved_url_params) + + def _parse_url(self, url: str) -> None: + # remove repeated slash ('/') + url = sub(r"\/{2,}", "/", url) + # root url need to be added manually because it is trimmed by the url split function + self.url_stack = ["/"] + self.url_stack.extend([_ for _ in urlparse(url).path.split("/") if _ != ""]) + + def reset_url_stack(self) -> None: + self.url_stack = self._saved_url_stack.copy() + self.url_stack_index = 0 + + def get_url_stack(self) -> list[str]: + """retrieve the current url stack""" + + return self.url_stack + + def get_url(self) -> str: + """retrieve the raw url""" + return self.url + + def consume_url_stack(self, count: int) -> list[str]: + """consume some url stack elements + + Args: + count: number of element to consume + """ + + returned_stack: list[str] = [] + + for _ in range(count): + returned_stack.append(self.url_stack.pop(0)) + self.url_stack_index = self.url_stack_index + count + return returned_stack + + def get_data(self) -> dict: + """get the request data""" + return self.data + + def get_resource_origin(self, deepness: int = 0) -> str: + """get current or previous (consumed) resource in the url + + Args: + deepness: backward amount + """ + + return self._saved_url_stack[self.url_stack_index - deepness] + + def get_req_params(self) -> _T_RestRequestParams: + """get extracted req_params""" + return self.ReqParams + + def get_verb(self) -> rsrc_verb: + """get http request verb""" + return self.verb diff --git a/src/pyrestresource/rest_request_opt.py b/src/pyrestresource/rest_request_opt.py new file mode 100644 index 0000000..599bdbe --- /dev/null +++ b/src/pyrestresource/rest_request_opt.py @@ -0,0 +1,76 @@ +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring + +from __future__ import annotations +from typing import Generic, Optional, TypeVar, TYPE_CHECKING + +from pydantic import BaseModel, Extra + +from .rest_types import ( + _T_DictKey, +) + +if TYPE_CHECKING is True: + pass + + +class RestRequestParams(BaseModel, extra=Extra.allow): + pass + + +class RestRequestParams_POST(RestRequestParams): + pass + + +class RestRequestParams_DELETE(RestRequestParams): + pass + + +class RestRequestParams_GET(RestRequestParams): + pass + + +class RestRequestParams_PUT(RestRequestParams): + pass + + +class RestRequestParams_RestResourceBase(RestRequestParams): + pass + + +class RestRequestParams_RestResourceBase_PUT(RestRequestParams_PUT, RestRequestParams_RestResourceBase): + pass + + +class RestRequestParams_RestResourceBase_GET(RestRequestParams_GET, RestRequestParams_RestResourceBase): + pass + + +class RestRequestParams_Dict(RestRequestParams): + pass + + +class RestRequestParams_Dict_POST(RestRequestParams_Dict, RestRequestParams_POST, Generic[_T_DictKey]): + API_key: Optional[_T_DictKey] = None + + +class RestRequestParams_Dict_DELETE(RestRequestParams_Dict, RestRequestParams_DELETE, Generic[_T_DictKey]): + API_key: Optional[_T_DictKey] = None + + +class RestRequestParams_Dict_GET(RestRequestParams_Dict, RestRequestParams_GET): + pass + + +class RestRequestParams_Dict_elem_GET(RestRequestParams_Dict, RestRequestParams_GET): + pass + + +class RestRequestParams_Dict_elem_PUT(RestRequestParams_Dict, RestRequestParams_GET): + pass + + +_T_RestRequestParams = TypeVar("_T_RestRequestParams", bound=RestRequestParams) +_T_RestRequestParams_POST = TypeVar("_T_RestRequestParams_POST", bound=RestRequestParams_POST) +_T_RestRequestParams_DELETE = TypeVar("_T_RestRequestParams_DELETE", bound=RestRequestParams_DELETE) +_T_RestRequestParams_GET = TypeVar("_T_RestRequestParams_GET", bound=RestRequestParams_GET) +_T_RestRequestParams_PUT = TypeVar("_T_RestRequestParams_PUT", bound=RestRequestParams_PUT) diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py new file mode 100644 index 0000000..aca8879 --- /dev/null +++ b/src/pyrestresource/rest_resource.py @@ -0,0 +1,251 @@ +from __future__ import annotations +from typing import ( + Any, + ClassVar, + Optional, + TYPE_CHECKING, +) + +from abc import ABC +import json +import pprint + +from pydantic import BaseModel + +from .rest_types import rsrc_verb +from .helpers import _JSONEncoder, forward_exception + +from .rest_ACL import ( + ACL_record, + ACL_target_user, + ACL_target_group, + ACL_target_group_Any, + ACL_rule, +) + +from .rest_exceptions import ( + RestResourceLoginException_InvalidSession, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_ClientChange, + RestResourceLoginException_InvalidCredentials, + RestResourceHandlerException_ResourceNotFound, + RestResourceHandlerException_MethodNotAllowed, + RestResourceHandlerException_BadRequest, + RestResourceHandlerException_Forbiden, + RestResourceException, +) + +if TYPE_CHECKING is True: + from .rest_request import RestRequest + from .rest_types import T_SupportedRESTFields + from .rest_resource_plugin import ResourcePlugin + from .rest_types import ( + T_T_DictKey, + T_T_DictValues, + ) + + +class RestResourceBase(ABC, BaseModel, validate_assignment=True): + _dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {} + _dict_value_type_: ClassVar[dict[str, T_T_DictValues]] = {} + _model_dump_excluded_: ClassVar[dict[str, bool]] = {} + _primary_key_: ClassVar[Optional[str]] = None + _plugins_: ClassVar[ + dict[ + str, + type[ResourcePlugin], + ] + ] = {} + _ACL_record_: ClassVar[ + dict[ + str, + list[ACL_record], + ] + ] = {} + + def _check_acl(self, user: ACL_target_user, groups: list[ACL_target_group], verb: rsrc_verb, field: str, is_self: bool = True): + # print(f"evaluate self ACLs rule: {self._ACL_record_}") + # print(f"user: {user}") + # print(f"groups: {groups}") + if is_self and verb is rsrc_verb.GET and self.model_fields[field].exclude is True: + # print("ALLOWED (excluded field)") + return + for acl in self._ACL_record_[field]: + # print(f"evaluate ACL rule: {acl}") + if verb in acl.verbs: + if isinstance(acl.target, ACL_target_user): + if user == acl.target: + if acl.rule is ACL_rule.ALLOW: + # print("ALLOWED (user)") + return + raise RestResourceHandlerException_Forbiden(f"Not allowed access detected: {field}") + elif isinstance(acl.target, ACL_target_group): + if isinstance(acl.target, ACL_target_group_Any) or any(_ for _ in groups if _.name == acl.target.name): + if acl.rule is ACL_rule.ALLOW: + # print("ALLOWED (group)") + return + raise RestResourceHandlerException_Forbiden(f"Not allowed access detected: {field}") + else: + raise RestResourceException(f"Wrong ACL target type: {field}") + # print("ALLOWED (Default)") + + def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None: + """Check ACL on requested field access""" + self._check_acl(request.user, request.groups, request.get_verb(), request.get_resource_origin(req_index), False) + + def check_acl_self(self, request: RestRequest, new_data: Optional[dict[str, T_SupportedRESTFields]]) -> None: + """Check ACL on requested field operation (involving checking sub-fields)""" + if request.get_verb() is rsrc_verb.GET: + for key in self.model_fields.keys(): + self._check_acl(request.user, request.groups, rsrc_verb.GET, key) + elif request.get_verb() is rsrc_verb.PUT: + if new_data is not None: + for key in new_data.keys(): + if key in self.model_fields: + self._check_acl(request.user, request.groups, rsrc_verb.PUT, key) + else: + raise RestResourceException("Incompatible verb") + + def update(self, **new_data): + for field, value in new_data.items(): + setattr(self, field, value) + + async def read_body(self, receive): + """ + Read and return the entire body from an incoming ASGI message. + """ + body = b"" + more_body = True + + while more_body: + message = await receive() + body += message.get("body", b"") + more_body = message.get("more_body", False) + + return body + + async def __call__(self, scope, receive, send) -> None: + assert scope["type"] == "http" + + method = scope["method"] + + assert method in ["GET", "DELETE", "PUT", "POST"] + + if b"content-type" in scope["headers"]: + assert scope["headers"][b"content-type"] == b"application/json" + + # pprint.pprint(scope) + + body = await self.read_body(receive) + + request: RestRequest = self.process_request( + scope["path"], + rsrc_verb[scope["method"]], + body.decode("utf-8"), + scope["query_string"].decode("utf-8"), + scope["client"], + scope["headers"], + True, + ) + + assert request is not None + + header_resp: dict[str, Any] = { + "type": "http.response.start", + "status": request.get_status(), + "headers": [ + [b"content-type", b"application/json"], + ], + } + + for name, value in request.outgoing_cookie.items(): + header_resp["headers"].append(["Set-Cookie", f"{name}={value}"]) + + await send(header_resp) + + body = None + result = request.get_result() + if result: + body = result.encode("utf-8") + + await send( + { + "type": "http.response.body", + "body": body, + } + ) + + def _process_request_session(self, request: RestRequest) -> None: + pass + + def process_request( + self, + url: str, + verb: rsrc_verb = rsrc_verb.GET, + data_json: Optional[str] = None, + query_string: Optional[str] = None, + client: Optional[tuple[str, int]] = None, + headers: Optional[list[Any]] = None, + http_mode: bool = False, + ) -> RestRequest: + from .rest_resource_handler import ( + ResourceHandler, + ResourceHandler_RestResourceBase, + ) + + data: dict = {} + if data_json: + data = json.loads(data_json) + + # creating the root handler + ressource_handler: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) + + # preparing request & session + request: RestRequest = ressource_handler.get_request() + assert request is not None + + if headers is not None: + request.set_headers(headers) + + if client is not None: + request.set_client(client) + + try: + self._process_request_session(request) + + result = ressource_handler.process_verb() + + if isinstance(result, RestResourceBase): + request.set_result(json.dumps(result.model_dump(mode="json"))) + elif result is not None: + request.set_result(json.dumps(result, cls=_JSONEncoder)) + else: + request.set_result("null") + + except RestResourceHandlerException_ResourceNotFound as e: + request.set_resp_status(404) + forward_exception(e, not http_mode) + + except RestResourceHandlerException_MethodNotAllowed as e: + request.set_resp_status(405) + forward_exception(e, not http_mode) + + except RestResourceHandlerException_BadRequest as e: + request.set_resp_status(400) + forward_exception(e, not http_mode) + + except RestResourceHandlerException_Forbiden as e: + request.set_resp_status(403) + forward_exception(e, not http_mode) + + except ( + RestResourceLoginException_InvalidSession, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_ClientChange, + RestResourceLoginException_InvalidCredentials, + ) as e: + request.set_resp_status(401) + request.reset_resp_cookie("Authorization") + forward_exception(e, not http_mode) + + return request diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py new file mode 100644 index 0000000..6435cbf --- /dev/null +++ b/src/pyrestresource/rest_resource_handler.py @@ -0,0 +1,665 @@ +from __future__ import annotations +from typing import Optional, cast, TypeVar, Generic, Self, TYPE_CHECKING + +import abc + +from .rest_types import ( + NoneType, + rsrc_verb, + T_SupportedRESTFields, + T_DictKey, + _T_SupportedRESTFields, + T_Dict, + T_DictValues, +) +from .rest_resource import RestResourceBase +from .rest_request import RequestFactory +from .rest_resource_plugin import ( + ResourcePlugin_field, + ResourcePlugin_dict, + ResourcePlugin_RestResourceBase, +) +from .rest_request_opt import ( + RestRequestParams_POST, + RestRequestParams_DELETE, + RestRequestParams_GET, + RestRequestParams_PUT, + RestRequestParams_RestResourceBase_PUT, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, +) + +from .rest_exceptions import ( + RestResourceHandlerException, + RestResourceHandlerException_ResourceNotFound, + RestResourceHandlerException_MethodNotAllowed, + RestResourceHandlerException_BadRequest, + RestResourceHandlerException_Forbiden, +) + +if TYPE_CHECKING is True: + from .rest_types import T_T_DictKey, T_T_DictValues + from .rest_request import RestRequest + +_T_Resource = TypeVar("_T_Resource", T_DictValues, T_Dict, T_SupportedRESTFields, RestResourceBase) + + +class ResourceHandler( + abc.ABC, + Generic[ + _T_Resource, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ], +): + _ar_resource_handler_cls_: list[type[ResourceHandler]] = [] + _nb_url_element_to_consume_ = 1 + + _request_factory: RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ] = RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ]( + cls_RestRequestParams_GET=RestRequestParams_GET, + cls_RestRequestParams_PUT=RestRequestParams_PUT, + cls_RestRequestParams_POST=RestRequestParams_POST, + cls_RestRequestParams_DELETE=RestRequestParams_DELETE, + ) + + def __init__( + self, + resource: _T_Resource, + url: Optional[str] = None, + verb: Optional[rsrc_verb] = None, + data: Optional[dict] = None, + query_string: Optional[str] = None, + prev_handler: Optional[ResourceHandler] = None, + ) -> None: + self.prev_handler: Optional[ResourceHandler] = None + self.next_handler: Optional[ResourceHandler] = None + self.saved_url: list[str] = [] + self.resource: _T_Resource = resource + self.req: RestRequest + if prev_handler is not None: + self.prev_handler = prev_handler + self.req = prev_handler.get_request() + self._request_factory.update_RestRequest(self.req) + + elif None in [url, verb]: + raise RestResourceHandlerException("if req not set, url,verb must be setted") + else: + assert url is not None and verb is not None + assert isinstance(resource, RestResourceBase) + if data is None: + data = {} + self.req = self._request_factory.get_RestRequest(resource, url, verb, data, query_string) + + # print(f"[TRACE] creating {type(self).__name__}() with url={self.req.get_url_stack()}") + + @classmethod + def create_chained_handler(cls, other: ResourceHandler, resource: _T_Resource) -> Self: + return cls(resource, None, None, None, None, other) + + @classmethod + @abc.abstractmethod + def _check_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> bool: + return False + + @classmethod + def _get_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> type[ResourceHandler]: + for resource_handler_cls in cls._ar_resource_handler_cls_: + if resource_handler_cls._check_resource_handler(resource, req): + # print(f"[DEBUG] match ResourceHandler: {resource_handler_cls.__name__}") + return resource_handler_cls + raise RestResourceHandlerException(f"Unsupported Resource Type {type(resource).__name__}") + + @classmethod + def register_resource_handler(cls, other_cls) -> None: + cls._ar_resource_handler_cls_.append(other_cls) + return other_cls + + def get_request(self) -> RestRequest: + return self.req + + def process_verb( + self, + ) -> Optional[_T_Resource | T_DictKey | list[T_DictKey]]: + # print(f"[TRACE] {type(self).__name__}->process_verb()") + self._reset_context() + resource_handler = self._find_resource() + return resource_handler._process_verb() + + def access_resource( + self, + ) -> _T_Resource: + # print(f"[TRACE] {type(self).__name__}->access_resource()") + self._reset_context() + resource_handler = self._find_resource() + return resource_handler.resource + + def _reset_context(self) -> None: + self.req.reset_url_stack() + + def _find_resource( + self, + ) -> ResourceHandler[ + _T_Resource, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ]: + # print(f"[TRACE] {type(self).__name__}->_find_resource()") + # print(f"[DEBUG] {type(self).__name__}->resource = {type(self.resource).__name__}") + + if len(self.req.get_url_stack()) == 0: + return self + + self._check_access_rights() + + next_resource = self._process_get() + # reveal_type(next_resource) + _next_resource = cast(_T_Resource, next_resource) + # reveal_type(_next_resource) + # print(f"[DEBUG] next_resource = {type(next_resource).__name__}") + + if ( + isinstance(_next_resource, RestResourceBase) + or isinstance(_next_resource, dict) + or type(_next_resource) in _T_SupportedRESTFields + ): + next_resource_handler_cls: type[ResourceHandler] = self._get_resource_handler(_next_resource, self.req) + + self.saved_url = self.req.consume_url_stack(self._nb_url_element_to_consume_) + + # we always create a new ResourceHandler context because we might want to access + # previous saved/chained ones (for exemple to put/post/delete values in containers) + # if next_resource_handler_cls != type(self): + # print(f"[DEBUG] CHANGING HANDLER to {next_resource_handler_cls.__name__}") + + next_resource_handler: ResourceHandler = next_resource_handler_cls.create_chained_handler(self, _next_resource)._find_resource() + self.next_handler = next_resource_handler + return next_resource_handler + + # in _find_resource context, only resource's real values can be retrieved + raise RestResourceHandlerException_ResourceNotFound() + + def _check_access_rights(self): + pass + + def _process_verb( + self, + ) -> Optional[_T_Resource | T_DictKey | list[T_DictKey]]: + # print(f"[TRACE] {type(self).__name__}->_process_verb()") + + verb = self.req.get_verb() + + if verb is rsrc_verb.GET: + return self._process_get() + if verb is rsrc_verb.PUT: + self._process_put() + return None + if verb is rsrc_verb.POST: + return self._process_post() + if verb is rsrc_verb.DELETE: + self._process_delete() + return None + + raise RestResourceHandlerException_BadRequest("Invalid Verb") + + def _process_get( + self, + ) -> _T_Resource | list[T_DictKey]: + return self._handle_process_get(self.req.get_req_params()) + + def _process_put(self) -> None: + self._handle_process_put(self.req.get_req_params()) + + def _process_post( + self, + ) -> Optional[T_DictKey]: + return self._handle_process_post(self.req.get_req_params()) + + def _process_delete( + self, + ) -> None: + self._handle_process_delete(self.req.get_req_params()) + + def _handle_process_get(self, params: _T_RestRequestParams_GET) -> _T_Resource | list[T_DictKey]: + raise RestResourceHandlerException_MethodNotAllowed(f"GET method not implemented for {type(self).__name__}") + + def _handle_process_put(self, params: _T_RestRequestParams_PUT) -> None: + raise RestResourceHandlerException_MethodNotAllowed(f"PUT method not implemented for {type(self).__name__}") + + def _handle_process_post(self, params: _T_RestRequestParams_POST) -> Optional[T_DictKey]: + raise RestResourceHandlerException_MethodNotAllowed(f"POST method not implemented for {type(self).__name__}") + + def _handle_process_delete(self, params: _T_RestRequestParams_DELETE) -> None: + raise RestResourceHandlerException_MethodNotAllowed(f"DELETE method not implemented for {type(self).__name__}") + + +@ResourceHandler.register_resource_handler +class ResourceHandler_dict( + ResourceHandler[ + T_Dict, + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, + _T_RestRequestParams_PUT, + ] +): + _nb_url_element_to_consume_ = 0 + + _request_factory: RequestFactory[ + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, + _T_RestRequestParams_PUT, + ] = RequestFactory[ + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, + _T_RestRequestParams_PUT, + ]( + cls_RestRequestParams_GET=RestRequestParams_Dict_GET, + cls_RestRequestParams_POST=RestRequestParams_Dict_POST, + cls_RestRequestParams_DELETE=RestRequestParams_Dict_DELETE, + ) + + @classmethod + def _check_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> bool: + # print(f"{cls.__name__}->_check_resource_handler()") + + return isinstance(resource, dict) and len(req.get_url_stack()) == 1 + + def _handle_process_get(self, params) -> list[T_DictKey]: + # print(f"{type(self).__name__}->_process_get()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + _dict: dict[T_DictKey, T_DictValues] = cast(dict[T_DictKey, T_DictValues], self.resource) + + return list(_dict.keys()) + + def _handle_process_delete(self, params) -> None: + # print(f"{type(self).__name__}->_handle_process_delete()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + assert self.prev_handler is not None + + dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] + + _dict: dict[T_DictKey, T_DictValues] = cast(dict[T_DictKey, T_DictValues], self.resource) + + if params.API_key is not None: + del _dict[dict_key_type(params.API_key)] + else: + _dict.clear() + return + + def _handle_process_post(self, params) -> Optional[T_DictKey]: + # pylint: disable=protected-access + + # print(f"{type(self).__name__}->_handle_process_post()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + assert self.prev_handler is not None + + dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] + + dict_value_type: T_T_DictValues = cast(RestResourceBase, self.prev_handler.resource)._dict_value_type_[ + self.req.get_resource_origin(1) + ] + + _obj: T_DictValues + if not issubclass(dict_value_type, NoneType): # type: ignore # => mypy bug with type[None] + _obj = dict_value_type(**self.req.get_data()) # type: ignore # => mypy bug with type[None] + else: + _obj = None + + _dict: dict[T_DictKey, T_DictValues] = cast(dict[T_DictKey, T_DictValues], self.resource) + + # 1st try/ using request param provided dict API_key + if params.API_key is not None: + # if a primary key is set for the resource, updating it + if isinstance(_obj, RestResourceBase): + if _obj._primary_key_ is not None: + _pri: T_DictKey = dict_key_type(params.API_key) + setattr(_obj, _obj._primary_key_, _pri) + # storing resource + _dict[dict_key_type(params.API_key)] = _obj + return dict_key_type(params.API_key) + + # 2nd try/ using provided resource internal primary key + # & 3rd try/ using resource internal auto-generated primary key + # => this case is automatic because if self.req.get_data() doesn't contain the key, it should be automatically created + if isinstance(_obj, RestResourceBase): + if _obj._primary_key_ is not None: + _obj_primary_key: Optional[T_DictKey] = getattr(_obj, _obj._primary_key_) + if _obj_primary_key is not None: + _dict[_obj_primary_key] = _obj + return _obj_primary_key + + raise RestResourceHandlerException_BadRequest( + "Either the object needs defined primary key or the request must contain an API_key param to process this command" + ) + return None # for mypy.... + + +@ResourceHandler.register_resource_handler +class ResourceHandler_dict_elem( + ResourceHandler[ + T_DictValues, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + _T_RestRequestParams_PUT, + ] +): + _nb_url_element_to_consume_ = 1 + + _request_factory: RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + _T_RestRequestParams_PUT, + ] = RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + _T_RestRequestParams_PUT, + ]( + cls_RestRequestParams_GET=RestRequestParams_RestResourceBase_GET + ) + + @classmethod + def _check_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> bool: + # print(f"{cls.__name__}->_check_resource_handler()") + + return isinstance(resource, dict) and len(req.get_url_stack()) > 1 + + def _handle_process_get(self, params) -> T_DictValues: + # print(f"{type(self).__name__}->_process_get()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + assert self.prev_handler is not None + + dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] + + if issubclass(dict_key_type, bytes): + key_byte = dict_key_type(self.req.get_resource_origin(0), "utf-8") + return cast(dict[T_DictKey, T_DictValues], self.resource)[key_byte] + else: + key = dict_key_type(self.req.get_resource_origin(0)) + return cast(dict[T_DictKey, T_DictValues], self.resource)[key] + + def _handle_process_delete(self, params) -> None: + # print(f"{type(self).__name__}->_handle_process_delete()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + # this is a litle bit tricky because this call comes from a next resource, so get_resource_origin(2) + # instead of expected get_resource_origin(1) because we need to go backward + # because self.req is another context that is not saved to improve performances + + assert self.prev_handler is not None + + dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(2)] + + if issubclass(dict_key_type, bytes): + key_byte = dict_key_type(self.req.get_resource_origin(1), "utf-8") + del cast(dict[T_DictKey, T_DictValues], self.resource)[key_byte] + else: + key = dict_key_type(self.req.get_resource_origin(1)) + del cast(dict[T_DictKey, T_DictValues], self.resource)[key] + + +@ResourceHandler.register_resource_handler +class ResourceHandler_RestResourceBase( + ResourceHandler[ + RestResourceBase, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_RestResourceBase_PUT, + ] +): + _request_factory: RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_RestResourceBase_PUT, + ] = RequestFactory[ + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_RestResourceBase_PUT, + ]( + cls_RestRequestParams_GET=RestRequestParams_RestResourceBase_GET, + cls_RestRequestParams_PUT=RestRequestParams_RestResourceBase_PUT, + ) + + @classmethod + def _check_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> bool: + # print(f"{cls.__name__}->_check_resource_handler()") + + return isinstance(resource, RestResourceBase) + + def _check_access_rights(self) -> None: + super()._check_access_rights() + # print(f"{type(self).__name__}->_check_access_rights()") + + if self.req.get_resource_origin(0) == "/": + return + + # print("==================") + # print(self.req.get_resource_origin(0)) + # print(len(self.req.get_url_stack())) + # print(self.resource._model_dump_excluded_) + # print(type(self.resource)) + # print(self.resource.exclude) + + if self.req.get_resource_origin(0) not in self.resource.model_fields: + raise RestResourceHandlerException_ResourceNotFound(f"Unknown field access detected: {self.req.get_url_stack()}") + + self.resource.check_acl_field(self.req) + + if len(self.req.get_url_stack()) == 0: # destination reached + if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET: + raise RestResourceHandlerException_ResourceNotFound(f"Not allowed READ access detected: {self.req.get_url_stack()}") + + def _handle_process_get(self, params) -> RestResourceBase: + # print(f"{type(self).__name__}->_process_get()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + # CASE 1: no more item in url_stack => we reached the endpoint (operation) + # So we are in a RestResourceBase instance and must return the content + plugin_field: ResourcePlugin_field + plugin_resource: ResourcePlugin_RestResourceBase + + if len(self.req.get_url_stack()) == 0: + self.resource.check_acl_self(self.req, None) + for key, attr in self.resource.model_fields.items(): + if key in self.resource._plugins_: + if issubclass(self.resource._plugins_[key], ResourcePlugin_field): + plugin_field = cast(ResourcePlugin_field, self.resource._plugins_[key](self.req, self.req.get_root_resource())) + value = getattr(self.resource, key) + setattr(self.resource, key, plugin_field.handle_field_get(value, params)) + elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): + plugin_resource = cast( + ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.req.get_root_resource()) + ) + value = getattr(self.resource, key) + setattr(self.resource, key, plugin_resource.handle_resource_get(value, params)) + + # result = RestResourceWalker_Root__handler(self.resource).process() + # print(result) + return self.resource + + # CASE 2: specific (operation) case for root Node + # TODO: this must probably be merged with the previous bloc + if self.req.get_resource_origin(0) == "/": + return self.resource + + # CASE 3: in between (access) + self.resource.check_acl_field(self.req) + value = getattr(self.resource, self.req.get_resource_origin(0)) + + key = self.req.get_resource_origin(0) + if key in self.resource._plugins_: + if issubclass(self.resource._plugins_[key], ResourcePlugin_field): + plugin_field = cast( + ResourcePlugin_field, + self.resource._plugins_[key](self.req, self.req.get_root_resource()), + ) + value = plugin_field.handle_field_get(value, params) + + elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): + plugin_resource = cast( + ResourcePlugin_RestResourceBase, + self.resource._plugins_[key](self.req, self.req.get_root_resource()), + ) + value = plugin_resource.handle_resource_get(value, params) + + return value + + def _handle_process_put(self, params) -> None: + # print(f"{type(self).__name__}->_process_put()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + self.resource.check_acl_self(self.req, self.req.get_data()) + + # creating a copy of the current resource + _new_resrc = self.resource.copy() + # updating values based on nex data + _new_resrc.update(**self.req.get_data()) + + # applying plugins (to nested element) + if isinstance(_new_resrc, RestResourceBase): + for key, attr in _new_resrc.model_fields.items(): + if key in _new_resrc._plugins_: + if issubclass(_new_resrc._plugins_[key], ResourcePlugin_field): + plugin_field: ResourcePlugin_field = cast( + ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.req.get_root_resource()) + ) + value = getattr(_new_resrc, key) + setattr(_new_resrc, key, plugin_field.handle_field_put(value, params)) + + # applying plugins (from parent element) + if self.prev_handler is not None: + # element is within a dict + if ( + isinstance(self.prev_handler.resource, dict) + and self.prev_handler.prev_handler is not None + and isinstance(self.prev_handler.prev_handler.resource, RestResourceBase) + ): + key = self.req.get_resource_origin(2) + if key in self.prev_handler.prev_handler.resource._plugins_: + plugin_dict: ResourcePlugin_dict = cast( + ResourcePlugin_dict, + self.prev_handler.prev_handler.resource._plugins_[key](self.req, self.req.get_root_resource()), + ) + _new_resrc = plugin_dict.handle_dict_elem_put(_new_resrc, params) + # element is within a RestResourceBase + elif isinstance(self.prev_handler.resource, RestResourceBase): + key = self.req.get_resource_origin(1) + if key in self.prev_handler.resource._plugins_: + plugin_rsrc: ResourcePlugin_RestResourceBase = cast( + ResourcePlugin_RestResourceBase, + self.prev_handler.resource._plugins_[key](self.req, self.req.get_root_resource()), + ) + _new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params) + + self.resource.update(**_new_resrc.__dict__) + return + + def _handle_process_delete(self, params) -> None: + # print(f"{type(self).__name__}->_handle_process_delete()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + # DELETING an element can only be done from a dict => checking and forwarding + if ( + self.prev_handler is not None + and isinstance(self.prev_handler.resource, dict) + and self.prev_handler.prev_handler is not None + and isinstance(self.prev_handler.prev_handler.resource, RestResourceBase) + ): + self.prev_handler._process_delete() + else: + raise RestResourceHandlerException_BadRequest("cannot delete an element outside a dict") + + +@ResourceHandler.register_resource_handler +class ResourceHandler_simple( + ResourceHandler[ + T_SupportedRESTFields, + _T_RestRequestParams_POST, + _T_RestRequestParams_DELETE, + _T_RestRequestParams_GET, + _T_RestRequestParams_PUT, + ] +): + @classmethod + def _check_resource_handler(cls, resource: _T_Resource, req: RestRequest) -> bool: + # print(f"{cls.__name__}->_check_resource_handler()") + + return type(resource) in _T_SupportedRESTFields + + def _handle_process_get(self, params) -> T_SupportedRESTFields: + # print(f"{type(self).__name__}->_process_get()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + assert self.prev_handler is not None + assert isinstance(self.prev_handler.resource, RestResourceBase) + + self.prev_handler.resource.check_acl_field(self.req, 1) + + if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: + plugin_simple: ResourcePlugin_field = cast( + ResourcePlugin_field, + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.req.get_root_resource()), + ) + return plugin_simple.handle_field_get(self.resource, params) + + return self.resource + + def _handle_process_put(self, params) -> None: + # print(f"{type(self).__name__}->_process_put()") + # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + + assert self.prev_handler is not None + assert isinstance(self.prev_handler.resource, RestResourceBase) + + self.prev_handler.resource.check_acl_field(self.req, 1) + + value = self.req.get_data() + + if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: + # print("PLUGIN FOUND") + plugin_simple: ResourcePlugin_field = cast( + ResourcePlugin_field, + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.req.get_root_resource()), + ) + # print(value) + value = plugin_simple.handle_field_put(value, params) + # print(value) + + # print(self.req.get_resource_origin(1)) + setattr( + self.prev_handler.resource, + self.req.get_resource_origin(1), + value, + ) + # print(self.prev_handler.resource) diff --git a/src/pyrestresource/rest_resource_handler_walker.py b/src/pyrestresource/rest_resource_handler_walker.py new file mode 100644 index 0000000..62f0df9 --- /dev/null +++ b/src/pyrestresource/rest_resource_handler_walker.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# pyrestresource(c) by chacha +# +# pyrestresource is licensed under a +# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. +# +# You should have received a copy of the license along with this +# work. If not, see . + +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring + +"""CLI interface module""" +from __future__ import annotations +from typing import TYPE_CHECKING + +from .rest_resource_walker import ( + RestResourceWalkerFutureResult, + RestResourceWalker_Root, + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, +) + +if TYPE_CHECKING is True: + from typing import Optional, Any + + +class RestResourceWalkerFutureResult_RestResourceBase_handler(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + # print(f"RestResourceWalkerFutureResult_RestResourceBase_handler {result}") + res: dict[str, Any] = {} + res[self.source.resource_name] = {} + if result: + for subres in result: + key = next(iter(subres)) + print(key) + res[self.source.resource_name] = res[self.source.resource_name] | subres + return res + + +class RestResourceWalkerFutureResult_Dict_handler(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + # print(f"RestResourceWalkerFutureResult_Dict_handler {result}") + res: dict[str, Any] = {} + if result: + for subres in result: + res = res | subres + return res + + +class RestResourceWalkerFutureResult_RestFields_handler(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + # print(f"RestResourceWalkerFutureResult_RestFields_handler {result}") + # print(self.source.resource) + res: dict[str, Any] = {} + res[self.source.resource_name] = {} + if result: + for subres in result: + key = next(iter(subres)) + print(key) + res[self.source.resource_name] = res[self.source.resource_name] | subres + return res + + +class RestResourceWalker_Sub_T_Dict__handler(RestResourceWalker_Sub_T_Dict): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_Dict_handler + + +class RestResourceWalker_Sub_RestResourceBase__handler(RestResourceWalker_Sub_RestResourceBase): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestResourceBase_handler + + +class RestResourceWalker_Sub_RestResourceFields__handler(RestResourceWalker_Sub_RestFields): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestFields_handler + + +class RestResourceWalker_Root__handler(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict__handler, + RestResourceWalker_Sub_RestResourceFields__handler, + RestResourceWalker_Sub_RestResourceBase__handler, + ] diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py new file mode 100644 index 0000000..a110c1e --- /dev/null +++ b/src/pyrestresource/rest_resource_plugin.py @@ -0,0 +1,203 @@ +from __future__ import annotations +from typing import Optional, Generic, TYPE_CHECKING + +from abc import abstractmethod, ABC +from datetime import datetime + +from .rest_types import ( + _T_DictValues, + _T_DictKey, + TV_SupportedRESTFields, + TV_RestResourceBase, +) +from .rest_exceptions import RestResourceConfigException + + +if TYPE_CHECKING is True: + from .rest_request import RestRequest + from .rest_resource import RestResourceBase + from .rest_request_opt import ( + RestRequestParams_GET, + RestRequestParams_PUT, + RestRequestParams_RestResourceBase_PUT, + RestRequestParams_RestResourceBase_GET, + RestRequestParams_Dict_POST, + RestRequestParams_Dict_DELETE, + RestRequestParams_Dict_GET, + RestRequestParams_Dict_elem_GET, + RestRequestParams_Dict_elem_PUT, + ) + + +class ResourcePlugin(ABC): + def __init__(self, request: RestRequest, root_resource: RestResourceBase) -> None: + self.__request: RestRequest = request + self.__root_resource: RestResourceBase = root_resource + + def user_login(self, user_name: str, user_secret: str) -> str: + from .rest_login import RestResourceBaseLogin + + if not isinstance(self.__root_resource, RestResourceBaseLogin): + raise RestResourceConfigException("root_resource must be RestResourceBaseLogin to use user_login") + return self.__root_resource.user_login(user_name, user_secret, self.__request) + + def get_user_login(self) -> str: + return self.__request.get_user().name + + def set_resp_cookie_value(self, key: str, value: str) -> None: + self.__request.set_resp_cookie_value(key, value) + + def reset_resp_cookie(self, key: str) -> None: + self.__request.reset_resp_cookie(key) + + def get_new_cookie_expiration_date(self) -> datetime: + from .rest_login import RestResourceBaseLogin + + if not isinstance(self.__root_resource, RestResourceBaseLogin): + raise RestResourceConfigException("root_resource must be RestResourceBaseLogin to use get_new_cookie_expiration_date") + return self.__root_resource.get_new_cookie_expiration_date() + + def set_resp_status(self, status: int) -> None: + self.__request.set_resp_status(status) + + +class ResourcePlugin_field(ResourcePlugin, Generic[TV_SupportedRESTFields]): + @abstractmethod + def handle_field_get(self, resource: TV_SupportedRESTFields, params: RestRequestParams_GET) -> TV_SupportedRESTFields: + ... + + @abstractmethod + def handle_field_put(self, resource: TV_SupportedRESTFields, params: RestRequestParams_PUT) -> TV_SupportedRESTFields: + ... + + +class ResourcePlugin_field_default(ResourcePlugin_field[TV_SupportedRESTFields]): + """default implementation of RestResourcePlugin_simple""" + + def handle_field_get(self, resource: TV_SupportedRESTFields, params: RestRequestParams_GET) -> TV_SupportedRESTFields: + return resource + + def handle_field_put(self, resource: TV_SupportedRESTFields, params: RestRequestParams_PUT) -> TV_SupportedRESTFields: + return resource + + +class ResourcePlugin_RestResourceBase(ResourcePlugin, Generic[TV_RestResourceBase]): + @abstractmethod + def handle_resource_get( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_RestResourceBase_GET, + ) -> TV_RestResourceBase: + ... + + @abstractmethod + def handle_resource_put( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_RestResourceBase_PUT, + ) -> TV_RestResourceBase: + ... + + +class ResourcePlugin_RestResourceBase_default(ResourcePlugin_RestResourceBase[TV_RestResourceBase]): + """default implementation of RestResourcePlugin_RestResourceBase""" + + def handle_resource_get( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_RestResourceBase_GET, + ) -> TV_RestResourceBase: + return resource + + def handle_resource_put( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_RestResourceBase_PUT, + ) -> TV_RestResourceBase: + return resource + + +class ResourcePlugin_dict(ResourcePlugin, Generic[_T_DictKey, _T_DictValues]): + @abstractmethod + def handle_dict_get_keys( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + params: RestRequestParams_Dict_GET, + ) -> list[_T_DictKey]: + ... + + @abstractmethod + def handle_dict_post( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + resource: _T_DictValues, + params: RestRequestParams_Dict_POST[_T_DictKey], + ) -> Optional[_T_DictKey]: + ... + + @abstractmethod + def handle_dict_delete( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + params: RestRequestParams_Dict_DELETE[_T_DictKey], + ) -> None: + ... + + @abstractmethod + def handle_dict_elem_get( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_Dict_elem_GET, + ) -> TV_RestResourceBase: + ... + + @abstractmethod + def handle_dict_elem_put( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_Dict_elem_PUT, + ) -> TV_RestResourceBase: + ... + + +class ResourcePlugin_dict_default(ResourcePlugin_dict[_T_DictKey, _T_DictValues]): + """default implementation of RestResourcePlugin_dict""" + + def handle_dict_get_keys( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + params: RestRequestParams_Dict_GET, + ) -> list[_T_DictKey]: + return list(resource_dict.keys()) + + def handle_dict_post( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + resource: _T_DictValues, + params: RestRequestParams_Dict_POST[_T_DictKey], + ) -> Optional[_T_DictKey]: + if params.API_key is not None: + resource_dict[params.API_key] = resource + return params.API_key + + def handle_dict_delete( + self, + resource_dict: dict[_T_DictKey, _T_DictValues], + params: RestRequestParams_Dict_DELETE[_T_DictKey], + ) -> None: + if params.API_key is not None: + del resource_dict[params.API_key] + + def handle_dict_elem_get( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_Dict_elem_GET, + ) -> TV_RestResourceBase: + return resource + + def handle_dict_elem_put( + self, + resource: TV_RestResourceBase, + params: RestRequestParams_Dict_elem_PUT, + ) -> TV_RestResourceBase: + return resource diff --git a/src/pyrestresource/rest_resource_rootpoint.py b/src/pyrestresource/rest_resource_rootpoint.py new file mode 100644 index 0000000..caec5cd --- /dev/null +++ b/src/pyrestresource/rest_resource_rootpoint.py @@ -0,0 +1,182 @@ +from __future__ import annotations +from typing import ( + cast, + get_args, + get_origin, + TYPE_CHECKING, +) + +from pydantic import BaseModel +from pydantic.fields import FieldInfo + +from .rest_resource import RestResourceBase +from .rest_resource_plugin import ( + ResourcePlugin_field, + ResourcePlugin_RestResourceBase, + ResourcePlugin_dict, +) +from .rest_resource_walker import ( + RestResourceWalker_Root, + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, +) +from .rest_types import rsrc_verb, _T_SupportedRESTFields +from .rest_ACL import ( + ACL_record, + ACL_target_group_Any, + ACL_rule, +) +from .rest_exceptions import RestResourcePluginException_InvalidPluginSignature, RestResourceModelException, RestResourceModelException_ACL + +if TYPE_CHECKING is True: + pass + + +class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): + def process(self) -> None: + datatype = get_args(self.annotation) + + # checking compatibility + if not get_origin(datatype[1]) is None: + raise RestResourceModelException("complex dict types are not supported (should create a RestResourceBase container)") + if not datatype[0] in _T_SupportedRESTFields: + raise RestResourceModelException(f"Unsupported Dict Field value type in class (key)") + + # preprocessing types / structure + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + self.parent.annotation._dict_key_type_[self.resource_name] = datatype[0] # pylint: disable=protected-access + self.parent.annotation._dict_value_type_[self.resource_name] = datatype[1] # pylint: disable=protected-access + self.parent.annotation._model_dump_excluded_[self.resource_name] = True # pylint: disable=protected-access + + assert isinstance(self.resource, FieldInfo) + current_resource = cast(FieldInfo, self.resource) + current_resource.exclude = True + + parent_resource = cast(type[RestResourceBase], self.parent.resource) + assert issubclass(parent_resource, RestResourceBase) + parent_resource.model_rebuild(force=True) + + self.parent.annotation._ACL_record_[self.resource_name] = [] + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + ): + if "plugin" in self.resource.json_schema_extra: + plugin_dict: type[ResourcePlugin_dict] = self.resource.json_schema_extra["plugin"] + if not issubclass(plugin_dict, ResourcePlugin_dict): + raise RestResourcePluginException_InvalidPluginSignature() + self.parent.annotation._plugins_[self.resource_name] = plugin_dict + # print("ADD DICT PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + # print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] + else: + raise RestResourceModelException_ACL("ACL must be a list()") + + else: + raise RestResourceModelException("dict must be contained in a RestResourceBase") + + +class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFields): + def process(self) -> None: + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + # import pprint + + # print("1aaaaaaaaaa") + # pprint.pprint(self.resource.json_schema_extra) + # pprint.pprint(self.annotation) + # pprint.pprint(self.resource.exclude) + + self.parent.annotation._ACL_record_[self.resource_name] = [] + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + ): + # print("aaaaaaaaaa") + + if "primary_key" in self.resource.json_schema_extra and self.resource.json_schema_extra["primary_key"] is True: + if self.parent.annotation._primary_key_ is not None: + raise RestResourceModelException( + f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}" + ) + self.parent.annotation._primary_key_ = self.resource_name + self.parent.annotation._ACL_record_[self.resource_name] = [ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY) + ] + + if "plugin" in self.resource.json_schema_extra: + plugin_field: type[ResourcePlugin_field] = self.resource.json_schema_extra["plugin"] + if not issubclass(plugin_field, ResourcePlugin_field): + raise RestResourcePluginException_InvalidPluginSignature() + self.parent.annotation._plugins_[self.resource_name] = plugin_field + # print("ADD FIELD PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + # print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] + else: + raise RestResourceModelException_ACL("ACL must be a list()") + + else: + raise RestResourceModelException("fields must be contained in a RestResourceBase") + + +class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_RestResourceBase): + def process(self) -> None: + setattr(self.annotation, "_dict_key_type_", {}) + setattr(self.annotation, "_dict_value_type_", {}) + setattr(self.annotation, "_model_dump_excluded_", {}) + setattr(self.annotation, "_primary_key_", None) + setattr(self.annotation, "_plugins_", {}) + setattr(self.annotation, "_ACL_record_", {}) + + # preprocessing types / structure + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + self.parent.annotation._model_dump_excluded_[self.resource_name] = True + assert isinstance(self.resource, FieldInfo) + current_resource = cast(FieldInfo, self.resource) + current_resource.exclude = True + parent_resource = cast(type[RestResourceBase], self.parent.resource) + assert issubclass(parent_resource, RestResourceBase) + parent_resource.model_rebuild(force=True) + self.parent.annotation._ACL_record_[self.resource_name] = [] + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + ): + if "plugin" in self.resource.json_schema_extra: + plugin_resource: type[ResourcePlugin_RestResourceBase] = self.resource.json_schema_extra["plugin"] + if not issubclass(plugin_resource, ResourcePlugin_RestResourceBase): + raise RestResourcePluginException_InvalidPluginSignature() + self.parent.annotation._plugins_[self.resource_name] = plugin_resource + # print("ADD RESOURCE PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + # print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] + else: + raise RestResourceModelException_ACL("ACL must be a list()") + + +class RestResourceWalker_Root__tree_init(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict__tree_init, + RestResourceWalker_Sub_RestFields__tree_init, + RestResourceWalker_Sub_RestResourceBase__tree_init, + ] + + +def register_rest_rootpoint(klass: type[RestResourceBase]): + RestResourceWalker_Root__tree_init(klass).process() + return klass diff --git a/src/pyrestresource/rest_resource_walker.py b/src/pyrestresource/rest_resource_walker.py new file mode 100644 index 0000000..fc68aa9 --- /dev/null +++ b/src/pyrestresource/rest_resource_walker.py @@ -0,0 +1,256 @@ +from __future__ import annotations +from typing import ( + cast, + Union, + get_args, + get_origin, + TypeVar, + Type, + Generic, + TYPE_CHECKING, +) + +from abc import ABC, abstractmethod +from pydantic.fields import FieldInfo + +from .rest_types import _T_SupportedRESTFields +from .rest_resource import RestResourceBase +from .rest_exceptions import RestResourceModelException + +if TYPE_CHECKING is True: + from typing import Any, Optional + +TV_RestResourceWalkerFutureResult = TypeVar("TV_RestResourceWalkerFutureResult") + + +class RestResourceWalkerFutureResult(ABC, Generic[TV_RestResourceWalkerFutureResult]): + def __init__(self, source: RestResourceWalker_Sub): + self.source: RestResourceWalker_Sub = source + + def chain_process_future(self) -> Optional[TV_RestResourceWalkerFutureResult]: + return self.source.chain_process_future() + + @abstractmethod + def process_future(self, result: Optional[list[TV_RestResourceWalkerFutureResult]]) -> Optional[TV_RestResourceWalkerFutureResult]: + pass + + +class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): + cls_RestResourceWalkerFutureResult: Optional[type[RestResourceWalkerFutureResult[TV_RestResourceWalkerFutureResult]]] = None + + @classmethod + @abstractmethod + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: + """implementation interface to Factory. + The factory will call this specialized method on each implementation to find a supported one. + """ + ... + + @classmethod + def get( + self, + subs: list[type[RestResourceWalker_Sub]], + resource_name: str, + resource: FieldInfo | Type[RestResourceBase], + parent: Optional[RestResourceWalker_Sub] = None, + argument: Optional[Any] = None, + ) -> Optional[RestResourceWalker_Sub]: + for sub in subs: + _is_valid, _anno, _optional = sub.check_type(resource) + + if _is_valid is True: + return sub(resource_name, resource, parent, _anno, _optional, argument) + raise RestResourceModelException(f"Incompatible Field Found: {type(resource).__name__}") + return None + + def __init__( + self, + resource_name: str, + resource: FieldInfo | Type[RestResourceBase], + parent: Optional[RestResourceWalker_Sub] = None, + annotation: Optional[type[RestResourceBase]] = None, + _optional: Optional[bool] = None, + argument: Optional[Any] = None, + ): + self.argument: Any = argument + self.resource_name: str = resource_name + self.resource: FieldInfo | Type[RestResourceBase] = resource + self.parent: Optional[RestResourceWalker_Sub] = parent + + self.future_results_subs: Optional[list[RestResourceWalkerFutureResult[TV_RestResourceWalkerFutureResult]]] = None + self.future_result: Optional[RestResourceWalkerFutureResult[TV_RestResourceWalkerFutureResult]] = None + if self.cls_RestResourceWalkerFutureResult is not None: + self.future_results_subs = [] + self.future_result = self.cls_RestResourceWalkerFutureResult(self) + + self.annotation: type[RestResourceBase] + self.optional: bool + if annotation is None or _optional is None: + self.annotation, self.optional = self.ProcessAnnotation(resource) + else: + self.annotation = annotation + self.optional = _optional + + if self.annotation is None: + raise RestResourceModelException("Only annotated types are allowed in RestResourceBase derived classes") + + self.subdatatype = get_args(self.annotation) + + @abstractmethod + def get_future(self) -> Optional[RestResourceWalkerFutureResult]: + return self.future_result + + def chain_process_future(self) -> Optional[TV_RestResourceWalkerFutureResult]: + if self.future_results_subs is not None and self.future_result is not None: + return_future_results_subs: list[Any] = [] # TODO: use typevar + for future_result in self.future_results_subs: + return_future_results_subs.append(future_result.chain_process_future()) + return self.future_result.process_future(return_future_results_subs) + return None + + def collect_future_result( + self, + process_future_result: Optional[RestResourceWalkerFutureResult[TV_RestResourceWalkerFutureResult]], + ) -> None: + if process_future_result is not None and self.future_results_subs is not None and self.future_result is not None: + self.future_results_subs.append(process_future_result) + + # @abstractmethod + def get_sub_resources(self) -> list[tuple[str, FieldInfo]]: + return [] + + def process(self): + pass + + @staticmethod + def ProcessAnnotation( + resource: FieldInfo | Type[RestResourceBase], + ) -> tuple[type[Any], bool]: + # from .rest_resource import RestResourceBase + + _anno: Type[Any] + + # print("!!!!!!!!!!!!!!!!!!!!!!!") + # print(resource) + # print(type(resource)) + + if isinstance(resource, FieldInfo) and resource.annotation is not None: + _anno = resource.annotation + elif not isinstance(resource, FieldInfo) and issubclass(resource, RestResourceBase): + _anno = resource + else: + raise RestResourceModelException("Incompatible resource type") + + _datatype = get_args(_anno) + _optional: bool = False + if get_origin(_anno) is Union: + if len(_datatype) == 2: + if _datatype[0] is type(None): + _anno = _datatype[1] + _optional = True + elif _datatype[1] is type(None): + _anno = _datatype[0] + _optional = True + else: + raise RestResourceModelException("Union is only allowed to describe Optional (e.g. Union[XXX,None])") + + return _anno, _optional + + +class RestResourceWalker_Sub_T_Dict(RestResourceWalker_Sub): + @classmethod + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: + _anno, _optional = cls.ProcessAnnotation(resource) + _type_resource = get_origin(_anno) + return (_type_resource is dict), _anno, _optional + + def get_sub_resources(self) -> list[tuple[str, FieldInfo]]: + # print("????") + # print(self.subdatatype[1]) + return [(self.resource_name, self.subdatatype[1])] + + def get_future(self) -> Optional[RestResourceWalkerFutureResult]: + return self.future_result + + +class RestResourceWalker_Sub_RestFields(RestResourceWalker_Sub): + @classmethod + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: + _anno, _optional = cls.ProcessAnnotation(resource) + return (_anno in _T_SupportedRESTFields), _anno, _optional + + def get_future(self) -> Optional[RestResourceWalkerFutureResult]: + return self.future_result + + +class RestResourceWalker_Sub_RestResourceBase(RestResourceWalker_Sub): + @classmethod + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: + _anno, _optional = cls.ProcessAnnotation(resource) + return ( + ((get_origin(_anno) is None) and issubclass(_anno, RestResourceBase)), + _anno, + _optional, + ) + + def get_sub_resources(self) -> list[tuple[str, FieldInfo]]: + return [(cast(str, key), attr) for key, attr in self.annotation.model_fields.items()] + + def get_future(self) -> Optional[RestResourceWalkerFutureResult]: + return self.future_result + + +class RestResourceWalker_Root: + cls_RestResourceWalker_Sub: list[Type[RestResourceWalker_Sub]] = [ + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, + ] + + def __init__(self, resource: RestResourceBase | Type[RestResourceBase]) -> None: + self.subwalker_argument: Any = None + self.resource: Type[RestResourceBase] + if isinstance(resource, RestResourceBase): + self.resource = type(resource) + else: + self.resource = resource + + def process(self, argument: Optional[Any] = None, deep_limit: Optional[int] = None) -> Optional[TV_RestResourceWalkerFutureResult]: + current_deep: int = 0 + + sub_walker_initial: Optional[RestResourceWalker_Sub] = RestResourceWalker_Sub.get( + self.cls_RestResourceWalker_Sub, "/", self.resource, None, argument + ) + + if sub_walker_initial is not None: + sub_walker_initial.process() + sub_walker_initial.get_future() + resource_list: list[tuple[str, FieldInfo | Type[RestResourceBase], RestResourceWalker_Sub]] = [ + (subresource_name, subresource, sub_walker_initial) + for subresource_name, subresource in sub_walker_initial.get_sub_resources() + ] + + new_resource_list: list[tuple[str, FieldInfo, RestResourceWalker_Sub]] + sub_walker: Optional[RestResourceWalker_Sub] + while len(resource_list) > 0 and (deep_limit is None or current_deep < deep_limit): + new_resource_list = [] + for resource_name, resource, parent_sub_walker in resource_list: + sub_walker = RestResourceWalker_Sub.get( + self.cls_RestResourceWalker_Sub, resource_name, resource, parent_sub_walker, argument + ) + if sub_walker is not None: + sub_walker.process() + process_future_result: Optional[RestResourceWalkerFutureResult] = sub_walker.get_future() + parent_sub_walker.collect_future_result(process_future_result) + new_resource_list.extend( + [ + (subresource_name, subresource, sub_walker) + for subresource_name, subresource in sub_walker.get_sub_resources() + ] + ) + + resource_list = list(new_resource_list) + current_deep = current_deep + 1 + return sub_walker_initial.chain_process_future() + else: + raise RestResourceModelException("Invalid Rootpoint") diff --git a/src/pyrestresource/rest_types.py b/src/pyrestresource/rest_types.py new file mode 100644 index 0000000..8898690 --- /dev/null +++ b/src/pyrestresource/rest_types.py @@ -0,0 +1,101 @@ +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring +from __future__ import annotations +from typing import Union, get_origin, NewType, TypeVar, TYPE_CHECKING + +from enum import Enum, auto +from datetime import datetime +from pathlib import Path +from uuid import UUID +from ipaddress import IPv4Address, IPv4Network + +if TYPE_CHECKING is True: + from .rest_resource import RestResourceBase + +T_Gen_DictKeys: type = type({}.keys()) +NoneType = type(None) + + +class rsrc_verb(Enum): + GET = auto() + PUT = auto() + POST = auto() + DELETE = auto() + + +class rsrc_type(Enum): + resource = auto() + dict = auto() + list = auto() + field = auto() + + +_T_SupportedRESTFields = [ + UUID, + str, + int, + float, + bool, + bytes, + datetime, + Path, + IPv4Address, + IPv4Network, + NoneType, +] +T_SupportedRESTFields = Union[UUID, str, int, float, bool, bytes, datetime, Path, IPv4Address, IPv4Network, NoneType] +TV_SupportedRESTFields = TypeVar( + "TV_SupportedRESTFields", + UUID, + str, + int, + float, + bool, + bytes, + datetime, + Path, + IPv4Address, + IPv4Network, + NoneType, +) + +assert get_origin(T_SupportedRESTFields) is Union + +TV_RestResourceBase = TypeVar("TV_RestResourceBase", bound="RestResourceBase") + +T_FieldValue = Union[T_SupportedRESTFields, "RestResourceBase"] + + +T_ListIndex = NewType("T_ListIndex", int) +T_ListSize = NewType("T_ListSize", int) + +T_DictKey = Union[UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network] # datetime is removed because non-hashable +_T_DictKey = TypeVar("_T_DictKey", UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network) + +T_T_DictKey = type[T_DictKey] + + +T_DictValues = T_FieldValue +_T_DictValues = TypeVar( + "_T_DictValues", + UUID, + str, + int, + float, + bool, + bytes, + datetime, + Path, + IPv4Address, + IPv4Network, + "RestResourceBase", + NoneType, +) + +T_T_FieldValue = type(T_FieldValue) +T_T_DictValues = type[T_DictValues] + +T_Dict = dict[T_DictKey, T_DictValues] +_T_Dict = dict[_T_DictKey, _T_DictValues] + +T_AllSupportedFields = T_Dict | T_FieldValue +T_AllSupportedContainers = Union[T_Dict, "RestResourceBase"] diff --git a/src/pyrestresource/test_module.py b/src/pyrestresource/test_module.py deleted file mode 100644 index e0275df..0000000 --- a/src/pyrestresource/test_module.py +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# pyrestresource (c) by chacha -# -# pyrestresource is licensed under a -# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. -# -# You should have received a copy of the license along with this -# work. If not, see . - -"""Phasellus tellus lectus, volutpat eu dapibus ut, suscipit vel augue. - -Tips: - Aliquam non leo vel libero sagittis viverra. Quisque lobortis nunc sit amet augue euismod laoreet. -Note: - Maecenas volutpat porttitor pretium. Aliquam suscipit quis nisi non imperdiet. -Note: - Vivamus et efficitur lorem, eget imperdiet tortor. Integer vel interdum sem. -""" - -from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: # Only imports the below statements during type checking - pass - -def test_function(testvar: int) -> int: - """ A test function that return testvar+1 and print "Hello world !" - - Proin eget sapien eget ipsum efficitur mollis nec ac nibh. - - Note: - Morbi id lectus maximus, condimentum nunc eget, porta felis. In tristique velit tortor. - - Args: - testvar: any integer - - Returns: - testvar+1 - """ - print("Hello world !") - return testvar+1 diff --git a/test/ThreadedUvicorn.py b/test/ThreadedUvicorn.py new file mode 100644 index 0000000..b47d625 --- /dev/null +++ b/test/ThreadedUvicorn.py @@ -0,0 +1,23 @@ +from uvicorn import Config, Server +from threading import Thread +import asyncio + + +class ThreadedUvicorn: + def __init__(self, config: Config): + self.server = Server(config) + self.thread = Thread(daemon=True, target=self.server.run) + + def start(self): + self.thread.start() + asyncio.run(self.wait_for_started()) + + async def wait_for_started(self): + while not self.server.started: + await asyncio.sleep(0.1) + + def stop(self): + if self.thread.is_alive(): + self.server.should_exit = True + while self.thread.is_alive(): + continue diff --git a/test/__init__.py b/test/__init__.py index 8a7f597..006fd7e 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -4,4 +4,6 @@ # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. # # You should have received a copy of the license along with this -# work. If not, see . \ No newline at end of file +# work. If not, see . + +from .ThreadedUvicorn import ThreadedUvicorn diff --git a/test/test_ACL.py b/test/test_ACL.py new file mode 100644 index 0000000..ecf3483 --- /dev/null +++ b/test/test_ACL.py @@ -0,0 +1,178 @@ +from __future__ import annotations +import unittest +from os import chdir +from pathlib import Path +from typing import Optional + +from src.pyrestresource import ( + RestField, + RestResourceHandlerException_Forbiden, + register_rest_rootpoint, + RestResourceBase, + rsrc_verb, + RestRequestParams_GET, + RestRequestParams_POST, + RestRequestParams_Dict_GET, + RestRequestParams_PUT, + T_SupportedRESTFields, + ResourcePlugin_field_default, + ResourcePlugin_RestResourceBase_default, + ACL_target_group_Any, + ACL_record, + ACL_rule, +) + + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +# to allow mock-ing, all the tested classes are in a function +def init_classes(): + class TestResource(RestResourceBase): + username: Optional[str] = RestField(None) + secret: Optional[str] = RestField( + None, + exclude=True, + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + + class TestResource2(RestResourceBase): + version_ro: Optional[str] = RestField( + "1.2.3", + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + version: Optional[str] = RestField("3.2.1") + + @register_rest_rootpoint + class RootApp(RestResourceBase): + resource_with_secret: TestResource = RestField(default=TestResource()) + resource_with_secret_ACL: TestResource = RestField( + default=TestResource(), ACL=[ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY)] + ) + resource_ro: TestResource2 = RestField(TestResource2()) + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[TestResource.__name__] = TestResource + globals()[RootApp.__name__] = RootApp + + +class Test_RestAPI_ACL(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_subresource_readonly(self): + result = self.testapp.process_request("/", rsrc_verb.GET) + self.assertEqual(result.get_result(), "{}") + + result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "3.2.1"}') + + self.testapp.process_request("/resource_ro/version", rsrc_verb.PUT, '"6.6.6"') + + result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') + + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request("/resource_ro/version_ro", rsrc_verb.PUT, '"6.6.6"') + self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3") + + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request("/resource_ro", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}') + self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3") + + result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') + + def test_subresource(self): + result = self.testapp.process_request("/", rsrc_verb.GET) + self.assertEqual(result.get_result(), "{}") + + result = self.testapp.process_request("/resource_with_secret", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"username": null}') + + result = self.testapp.process_request("/resource_with_secret/username", rsrc_verb.GET) + self.assertEqual(result.get_result(), "null") + self.assertEqual(self.testapp.resource_with_secret.username, None) + + with self.assertRaises(RestResourceHandlerException_Forbiden): + self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) + + self.assertEqual(self.testapp.resource_with_secret.secret, None) + + result = self.testapp.process_request("/resource_with_secret", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') + self.assertEqual(result.get_result(), "null") + + result = self.testapp.process_request("/resource_with_secret", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"username": "chacha"}') + + result = self.testapp.process_request("/resource_with_secret/username", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"chacha"') + self.assertEqual(self.testapp.resource_with_secret.username, "chacha") + + with self.assertRaises(RestResourceHandlerException_Forbiden): + self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) + + self.assertEqual(self.testapp.resource_with_secret.secret, "123456") + + def test_subresource_field(self): + result = self.testapp.process_request("/resource_with_secret/username", rsrc_verb.PUT, '"chacha"') + self.assertEqual(result.get_result(), "null") + + result = self.testapp.process_request("/resource_with_secret", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"username": "chacha"}') + + result = self.testapp.process_request("/resource_with_secret/username", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"chacha"') + self.assertEqual(self.testapp.resource_with_secret.username, "chacha") + + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) + + result = self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.PUT, '"123456"') + self.assertEqual(result.get_result(), "null") + + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) + + self.assertEqual(self.testapp.resource_with_secret.secret, "123456") + + def test_subresource_ACL(self): + result = self.testapp.process_request("/", rsrc_verb.GET) + self.assertEqual(result.get_result(), "{}") + + result = self.testapp.process_request("/resource_with_secret_ACL", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"username": null}') + + result = self.testapp.process_request("/resource_with_secret_ACL/username", rsrc_verb.GET) + self.assertEqual(result.get_result(), "null") + self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) + + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request("/resource_with_secret_ACL/secret", rsrc_verb.GET) + + self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) + + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request("/resource_with_secret_ACL", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') + self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) + self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) + + def test_subresource_ACL_field(self): + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request("/resource_with_secret_ACL/username", rsrc_verb.PUT, '"chacha"') + self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) + self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) + + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request("/resource_with_secret_ACL/secret", rsrc_verb.PUT, '"123456"') + self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) + self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) diff --git a/test/test_rest_login.py b/test/test_rest_login.py new file mode 100644 index 0000000..8639cfc --- /dev/null +++ b/test/test_rest_login.py @@ -0,0 +1,625 @@ +from __future__ import annotations +import unittest +from os import chdir +from pathlib import Path +from typing import Optional, ClassVar +from time import sleep +import uvicorn +import socket +import requests +from contextlib import closing +from multiprocessing import Process +from requests.adapters import HTTPAdapter + +from src.pyrestresource import ( + RestField, + ACL_target_user, + UserLogin, + RestResourceBase, + RestResourceBaseLogin, + register_rest_rootpoint, + rsrc_verb, + RestRequestParams_GET, + RestRequestParams_POST, + RestRequestParams_Dict_GET, + RestRequestParams_PUT, + T_SupportedRESTFields, + ResourcePlugin_field_default, + ResourcePlugin_RestResourceBase_default, + ACL_target_group_Any, + ACL_record, + ACL_rule, +) + + +from test import ThreadedUvicorn + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +# to allow mock-ing, all the tested classes are in a function +def init_classes(): + user_test = UserLogin(username="TestUser", secret="123456") + user_test2 = UserLogin(username="TestUser2", secret="abcdef") + + class TestResource(RestResourceBase): + test_field: Optional[str] = RestField("ORIGIN_VALUE") + + class TestResourceACL(RestResourceBase): + test_field: Optional[str] = RestField( + "ORIGIN_VALUE", + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + test_field2: Optional[str] = RestField( + "ORIGIN_VALUE", + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test2), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + test_field_both: Optional[str] = RestField( + "ORIGIN_VALUE", + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test2), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + + @register_rest_rootpoint + class RootApp(RestResourceBaseLogin): + _ar_user_login: ClassVar[list[UserLogin]] = [user_test, user_test2] + test_resourceACL: TestResource = RestField( + TestResource(), + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user(name=user_test.username), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + test_resource: TestResourceACL = TestResourceACL() + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[TestResourceACL.__name__] = TestResourceACL + globals()[RootApp.__name__] = RootApp + + +def find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return "localhost", s.getsockname()[1] + + +class Test_RestAPI_LOGIN_Web(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + + def test_login_two_users(self): + ip, port = find_free_port() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() + sleep(1) + + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + # unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field2", json="TEST SET VALUE") + self.assertEqual(response.status_code, 403) + + # not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field2", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field_both", json="TEST SET VALUE 2") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field_both", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") + + # --------------------------------------- + # login 2 + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser2", "secret": "abcdef"}, + ) + self.assertEqual(response.status_code, 201) + + # unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="A TEST SET VALUE") + self.assertEqual(response.status_code, 403) + + # not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field2", json="A TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field2", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "A TEST SET VALUE") + + # previous (modified) value + response = s.get( + f"http://{ip}:{port}/test_resource/test_field_both", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field_both", json="A TEST SET VALUE 2") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field_both", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "A TEST SET VALUE 2") + + finally: + s.close() + server.stop() + + def test_login(self): + ip, port = find_free_port() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "TestUser"}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TestUser") + + finally: + s.close() + server.stop() + + def test_change_host(self): + ip, port = find_free_port() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() + sleep(1) + s1 = requests.Session() + s1.mount("http://", HTTPAdapter(max_retries=0)) + s2 = requests.Session() + s2.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # s1 - read full login resource + response = s1.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s1 - read login username field + response = s1.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # s2 - read full login resource + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s2 - read login username field + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # login s1 + response = s1.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # s1 - read full login resource + response = s1.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "TestUser"}) + + # s1 - read login username field + response = s1.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TestUser") + + # s2 - read full login resource + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s2 - read login username field + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # s2 -> spoof s1 token + s2.cookies.update(s1.cookies) + + # s2 - read full login resource + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s2.cookies.get_dict(), {}) + + # s2 - read full login resource (reseted cookie) + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s2 -> spoof s1 token + s2.cookies.update(s1.cookies) + + # s2 - read login username field + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s2.cookies.get_dict(), {}) + + # s2 - read full login resource (reseted cookie) + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + finally: + s1.close() + s2.close() + server.stop() + + def test_login_wrong_pwd(self): + ip, port = find_free_port() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + self.assertDictEqual(s.cookies.get_dict(), {}) + + # --------------------------------------------------- + # login (wrong pwd) + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "abc"}, + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + self.assertDictEqual(s.cookies.get_dict(), {}) + + # --------------------------------------------------- + # login (ok pwd) + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + self.assertTrue("Authorization" in response.cookies) + self.assertTrue("Authorization" in s.cookies.get_dict()) + self.assertTrue(s.cookies.get_dict()["Authorization"]) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "TestUser"}) + self.assertTrue("Authorization" in s.cookies.get_dict()) + self.assertTrue(s.cookies.get_dict()["Authorization"]) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TestUser") + self.assertTrue("Authorization" in s.cookies.get_dict()) + self.assertTrue(s.cookies.get_dict()["Authorization"]) + + # --------------------------------------------------- + # login (wrong pwd, after success) + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "abc"}, + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + self.assertDictEqual(s.cookies.get_dict(), {}) + + finally: + s.close() + server.stop() + + def test_access_resourceACL(self): + ip, port = find_free_port() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # before modification read + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resourceACL/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 403) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resourceACL", json={"test_field": "TEST SET VALUE"}) + self.assertEqual(response.status_code, 403) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resourceACL/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + # authenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resourceACL", json={"test_field": "TEST SET VALUE 2"}) + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") + + finally: + s.close() + server.stop() + + def test_access_fieldACL(self): + ip, port = find_free_port() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # before modification read + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 403) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resource", json={"test_field": "TEST SET VALUE"}) + self.assertEqual(response.status_code, 403) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + # authenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resource", json={"test_field": "TEST SET VALUE 2"}) + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") + + finally: + s.close() + server.stop() diff --git a/test/test_rest_resource.py b/test/test_rest_resource.py new file mode 100644 index 0000000..e169ce0 --- /dev/null +++ b/test/test_rest_resource.py @@ -0,0 +1,569 @@ +from __future__ import annotations +import unittest +from unittest.mock import patch +from os import chdir +from pathlib import Path +from typing import Optional +from uuid import UUID, uuid4 +from time import time +import json + + +print(__name__) +print(__package__) + +from src.pyrestresource import ( + RestField, + RestResourceHandlerException_Forbiden, + register_rest_rootpoint, + RestResourceBase, + rsrc_verb, + RestRequestParams_GET, + RestRequestParams_POST, + RestRequestParams_Dict_GET, + T_SupportedRESTFields, + ACL_target_group_Any, + ACL_record, + ACL_rule, +) +from pprint import pprint + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +# to allow mock-ing, all the tested classes are in a function +def init_classes(): + class Info(RestResourceBase): + version: str + api_version: str + + class Patch(RestResourceBase): + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + class Profile(RestResourceBase): + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + class Game(RestResourceBase): + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + profiles: dict[UUID, Profile] = {} + patchs: dict[UUID, Patch] = {} + + Patch_1 = Patch(uuid="cee1e870-65fa-11ee-8c99-0242ac120002", shortname="testPatch1") + Patch_2 = Patch(uuid="d385a1d2-65fa-11ee-8c99-0242ac120002", shortname="testPatch2") + + class User(RestResourceBase): + uuid: UUID = RestField( + default_factory=uuid4, + primary_key=True, + ) + name: str + secret: str = RestField( + ..., + exclude=True, + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + + User1 = User( + uuid="8da57a3c-661f-11ee-8c99-0242ac120002", + name="chacha", + secret="la blanquette est bonne", + ) + + class Patch2(RestResourceBase): + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + @register_rest_rootpoint + class RootApp(RestResourceBase): + testValueRoot: float = 3.14 + info: Info = Info(version="0.0.1", api_version="0.0.2") + games: dict[UUID, Game] = { + UUID("9b0381d4-65f6-11ee-8c99-0242ac120002"): Game( + uuid="9b0381d4-65f6-11ee-8c99-0242ac120002", + shortname="testGame", + patchs={Patch_1.uuid: Patch_1}, + profiles={ + UUID("aee1e870-65fa-11ee-8c99-0242ac120002"): Profile( + uuid="aee1e870-65fa-11ee-8c99-0242ac120002", + shortname="testprofile", + ) + }, + ) + } + patchs: dict[UUID, Patch] = {Patch_1.uuid: Patch_1, Patch_2.uuid: Patch_2} + users: dict[UUID, User] = {User1.uuid: User1} + + patchs2: dict[UUID, Patch2] = {} + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Info.__name__] = Info + globals()[Game.__name__] = Game + globals()[User.__name__] = User + globals()[Profile.__name__] = Profile + globals()[Patch.__name__] = Patch + globals()[Patch2.__name__] = Patch2 + globals()[RootApp.__name__] = RootApp + + +class Test_RestAPI_GET(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_get_root(self): + result = self.testapp.process_request("/", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"testValueRoot": 3.14}') + + def test_get_root__multiple_slash(self): + result = self.testapp.process_request("/////", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"testValueRoot": 3.14}') + + result = self.testapp.process_request("////", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"testValueRoot": 3.14}') + + def test_get_root__nested_value(self): + result = self.testapp.process_request("/testValueRoot", rsrc_verb.GET) + self.assertEqual(result.get_result(), "3.14") + + def test_get_root__nested_value__trailing_slash(self): + result = self.testapp.process_request("/testValueRoot/", rsrc_verb.GET) + self.assertEqual(result.get_result(), "3.14") + + result = self.testapp.process_request("/testValueRoot//", rsrc_verb.GET) + self.assertEqual(result.get_result(), "3.14") + + result = self.testapp.process_request("/testValueRoot///", rsrc_verb.GET) + self.assertEqual(result.get_result(), "3.14") + + def test_get_root__nested_value__multiple_slash(self): + result = self.testapp.process_request("//testValueRoot", rsrc_verb.GET) + self.assertEqual(result.get_result(), "3.14") + + result = self.testapp.process_request("///testValueRoot", rsrc_verb.GET) + self.assertEqual(result.get_result(), "3.14") + + def test_get_version(self): + result = self.testapp.process_request("/info", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') + + def test_get_version__trailing_slash(self): + result = self.testapp.process_request("/info/", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') + + result = self.testapp.process_request("/info//", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') + + result = self.testapp.process_request("/info///", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') + + def test_get_version__multiple_slash(self): + result = self.testapp.process_request("//info", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') + + result = self.testapp.process_request("///info", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "0.0.1", "api_version": "0.0.2"}') + + def test_get_version__nested_value(self): + result = self.testapp.process_request("/info/api_version", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"0.0.2"') + + result = self.testapp.process_request("/info/version", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"0.0.1"') + + def test_get_dict_games(self): + result = self.testapp.process_request("/games", rsrc_verb.GET) + self.assertEqual(result.get_result(), '["9b0381d4-65f6-11ee-8c99-0242ac120002"]') + + def test_get_dict_patchs(self): + result = self.testapp.process_request("/patchs", rsrc_verb.GET) + self.assertEqual( + result.get_result(), + '["cee1e870-65fa-11ee-8c99-0242ac120002", "d385a1d2-65fa-11ee-8c99-0242ac120002"]', + ) + + def test_get_dict_patch_element(self): + result = self.testapp.process_request("/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", rsrc_verb.GET) + self.assertEqual( + result.get_result(), + '{"uuid": "cee1e870-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch1", "name": null, "description": null}', + ) + + def test_get_dict_game_element(self): + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002", rsrc_verb.GET) + expected = '{"uuid": "9b0381d4-65f6-11ee-8c99-0242ac120002", "shortname": "testGame", "name": null, "description": null}' + self.assertEqual(result.get_result(), expected) + + def test_get_dict_game_element__nested_value(self): + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname", rsrc_verb.GET) + expected = '"testGame"' + self.assertEqual(result.get_result(), expected) + + def test_get_dict_game_element__nested_value2(self): + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/uuid", rsrc_verb.GET) + expected = '"9b0381d4-65f6-11ee-8c99-0242ac120002"' + self.assertEqual(result.get_result(), expected) + + def test_get_nested_dict_games_patchs(self): + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) + self.assertEqual(result.get_result(), '["cee1e870-65fa-11ee-8c99-0242ac120002"]') + + def test_get_nested_dict_games_patch_element(self): + result = self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.GET, + ) + expected = '{"uuid": "cee1e870-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch1", "name": null, "description": null}' + self.assertEqual(result.get_result(), expected) + + def test_get_nested_dict_games_patch_element__nested_value(self): + result = self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/uuid", + rsrc_verb.GET, + ) + self.assertEqual(result.get_result(), '"cee1e870-65fa-11ee-8c99-0242ac120002"') + + def test_get_dict_game_element__API_nested(self): + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002?API_nested=True", rsrc_verb.GET) + expected = '{"uuid": "9b0381d4-65f6-11ee-8c99-0242ac120002", "shortname": "testGame", "name": null, "description": null}' + self.assertEqual(result.get_result(), expected) + + def test_get_dict_users(self): + result = self.testapp.process_request("/users", rsrc_verb.GET) + self.assertEqual(result.get_result(), '["8da57a3c-661f-11ee-8c99-0242ac120002"]') + + def test_get_dict_user_element(self): + result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.GET) + self.assertEqual( + result.get_result(), + '{"uuid": "8da57a3c-661f-11ee-8c99-0242ac120002", "name": "chacha"}', + "no secret seen", + ) + + def test_get_dict_user_element2(self): + result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002?API_nested=True", rsrc_verb.GET) + self.assertEqual( + result.get_result(), + '{"uuid": "8da57a3c-661f-11ee-8c99-0242ac120002", "name": "chacha"}', + "no secret seen", + ) + + def test_get_dict_user_element__nested_value(self): + result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002/name", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"chacha"') + + def test_get_dict_user_element__nested_value__forbiden(self): + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002/secret", rsrc_verb.GET) + + def test_get_dict_user_element__nested_value__forbiden2(self): + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request( + "/users/8da57a3c-661f-11ee-8c99-0242ac120002/secret?API_nested=True", + rsrc_verb.GET, + ) + + +class Test_RestAPI_PUT(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_put_info(self): + self.testapp.process_request("/info", rsrc_verb.PUT, '{"version": "1.2.3", "api_version": "3.2.1"}') + + result = self.testapp.process_request("/info", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.2.3", "api_version": "3.2.1"}') + + def test_put_dict_user_nested_value(self): + self.testapp.process_request( + "/users/8da57a3c-661f-11ee-8c99-0242ac120002/name", + rsrc_verb.PUT, + '"chacha2"', + ) + + result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002/name", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"chacha2"') + + def test_put_user_nested_value__forbiden(self): + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception + self.testapp.process_request( + "/users/8da57a3c-661f-11ee-8c99-0242ac120002/uuid", + rsrc_verb.PUT, + '"test"', + ) + + def test_put_dict_user_element(self): + self.testapp.process_request( + "/users/8da57a3c-661f-11ee-8c99-0242ac120002", + rsrc_verb.PUT, + '{"name": "testUser4", "secret": "test5"}', + ) + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = '["8da57a3c-661f-11ee-8c99-0242ac120002"]' + self.assertEqual(result.get_result(), expected) + + result = self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.GET) + expected = '{"uuid": "8da57a3c-661f-11ee-8c99-0242ac120002", "name": "testUser4"}' + self.assertEqual(result.get_result(), expected) + + def test_put_dict_patch__nested(self): + self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.PUT, + '{"shortname": "testPatch998", "name": "MyPatch", "description": "MyDescription123"}', + ) + + result = self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.GET, + ) + expected = '{"uuid": "cee1e870-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch998", "name": "MyPatch", "description": "MyDescription123"}' + self.assertEqual(result.get_result(), expected) + + +class Test_RestAPI_POST(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_post_dict_user__API_key(self): + result = self.testapp.process_request( + "/users?API_key=e5e87d32-662b-11ee-8c99-0242ac120002", + rsrc_verb.POST, + '{"name": "testUser", "secret": "test"}', + ) + self.assertEqual(result.get_result(), '"e5e87d32-662b-11ee-8c99-0242ac120002"') + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "e5e87d32-662b-11ee-8c99-0242ac120002"]' + self.assertEqual(result.get_result(), expected) + + result = self.testapp.process_request("/users/e5e87d32-662b-11ee-8c99-0242ac120002", rsrc_verb.GET) + expected = '{"uuid": "e5e87d32-662b-11ee-8c99-0242ac120002", "name": "testUser"}' + self.assertEqual(result.get_result(), expected) + + def test_post_dict_user__nested_key(self): + result = self.testapp.process_request( + "/users", + rsrc_verb.POST, + '{"name": "testUser2", "secret": "test", "uuid":"e7e86d32-662b-11ee-8c99-0242ac120002"}', + ) + self.assertEqual(result.get_result(), '"e7e86d32-662b-11ee-8c99-0242ac120002"') + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "e7e86d32-662b-11ee-8c99-0242ac120002"]' + self.assertEqual(result.get_result(), expected) + + result = self.testapp.process_request("/users/e7e86d32-662b-11ee-8c99-0242ac120002", rsrc_verb.GET) + expected = '{"uuid": "e7e86d32-662b-11ee-8c99-0242ac120002", "name": "testUser2"}' + self.assertEqual(result.get_result(), expected) + + @patch(f"{__loader__.name }.uuid4") + def test_post_dict_user__auto_key(self, mock_uuid4): + mock_uuid4.return_value = UUID("5faccb2e-69aa-11ee-8c99-0242ac120002") + + # recreating classes & objects to force using the Mock-ed uuid4 + init_classes() + self.testapp = RootApp() + + result = self.testapp.process_request("/users", rsrc_verb.POST, '{"name": "testUser3", "secret": "test"}') + self.assertEqual(result.get_result(), '"5faccb2e-69aa-11ee-8c99-0242ac120002"') + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "5faccb2e-69aa-11ee-8c99-0242ac120002"]' + self.assertEqual(result.get_result(), expected) + + result = self.testapp.process_request("/users/5faccb2e-69aa-11ee-8c99-0242ac120002", rsrc_verb.GET) + expected = '{"uuid": "5faccb2e-69aa-11ee-8c99-0242ac120002", "name": "testUser3"}' + self.assertEqual(result.get_result(), expected) + + def test_post_dict_patch__nested_API_key(self): + self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs?API_key=cee1e971-65fa-11ee-8c99-0242ac120002", + rsrc_verb.POST, + '{"shortname": "testPatch99", "name": "MyPatch", "description": "MyDescription"}', + ) + + result = self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e971-65fa-11ee-8c99-0242ac120002", + rsrc_verb.GET, + ) + expected = '{"uuid": "cee1e971-65fa-11ee-8c99-0242ac120002", "shortname": "testPatch99", "name": "MyPatch", "description": "MyDescription"}' + self.assertEqual(result.get_result(), expected) + + +class Test_RestAPI_DELETE(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_delete_dict_user__API_key(self): + self.testapp.process_request("/users?API_key=8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.DELETE) + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result.get_result(), expected) + + def test_delete_dict_user__All(self): + result = self.testapp.process_request( + "/users?API_key=e5e87d32-662b-11ee-8c99-0242ac120002", + rsrc_verb.POST, + '{"name": "testUser", "secret": "test"}', + ) + self.assertEqual(result.get_result(), '"e5e87d32-662b-11ee-8c99-0242ac120002"') + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = '["8da57a3c-661f-11ee-8c99-0242ac120002", "e5e87d32-662b-11ee-8c99-0242ac120002"]' + self.assertEqual(result.get_result(), expected) + + self.testapp.process_request("/users", rsrc_verb.DELETE) + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result.get_result(), expected) + + def test_delete_dict_user_element(self): + self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.DELETE) + + result = self.testapp.process_request("/users", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result.get_result(), expected) + + def test_delete_nested_dict_games_patch_element(self): + self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.DELETE, + ) + + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result.get_result(), expected) + + def test_delete_nested_dict_games_patch_API_key(self): + self.testapp.process_request( + "/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs?API_key=cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.DELETE, + ) + + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result.get_result(), expected) + + def test_delete_nested_dict_games_patch_All(self): + self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.DELETE) + + result = self.testapp.process_request("/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs", rsrc_verb.GET) + expected = "[]" + self.assertEqual(result.get_result(), expected) + + +class Test_RestAPI_PERFO(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + @unittest.skip + def test_perf_dict(self): + print(f"LIB INTERNAL PERF TEST") + n_loop = 10000 + + start = time() + for _ in range(n_loop): + self.testapp.process_request(f"/users/8da57a3c-661f-11ee-8c99-0242ac120002", rsrc_verb.GET) + end = time() + print(f"GET 1st level dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + newUUID = uuid4() + self.testapp.process_request( + f"/users?API_key={newUUID}", + rsrc_verb.POST, + '{"name": "testUser", "secret": "test"}', + ) + end = time() + print(f"POST 1st level dict (API_key): {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + newUUID = uuid4() + self.testapp.process_request( + f"/users?API_key={newUUID}", + rsrc_verb.POST, + '{"name": "testUser", "secret": "test"}', + ) + self.testapp.process_request(f"/users/{newUUID}", rsrc_verb.GET) + end = time() + print(f"POST/GET 1st level dict (API_key): {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + result = self.testapp.process_request(f"/users", rsrc_verb.POST, '{"name": "testUser", "secret": "test"}') + self.testapp.process_request(f"/users/{json.loads(result.get_result())}", rsrc_verb.GET) + end = time() + print(f"POST/GET 1st level dict (autokey): {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + self.testapp.process_request( + f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname", + rsrc_verb.PUT, + '"TestValue!!"', + ) + self.testapp.process_request(f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname", rsrc_verb.GET) + end = time() + print(f"PUT/GET 1st level (value) dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + self.testapp.process_request( + f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002", + rsrc_verb.GET, + ) + end = time() + print(f"GET 2nd level dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + self.testapp.process_request( + f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + rsrc_verb.GET, + ) + end = time() + print(f"GET 2nd level (value) dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + self.testapp.process_request( + f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + rsrc_verb.PUT, + '"TestValue!!"', + ) + self.testapp.process_request( + f"/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + rsrc_verb.GET, + ) + end = time() + print(f"PUT/GET 2nd level (value) dict: {int(n_loop/(end-start))} Req/s") diff --git a/test/test_rest_resource_plugins.py b/test/test_rest_resource_plugins.py new file mode 100644 index 0000000..eecc27f --- /dev/null +++ b/test/test_rest_resource_plugins.py @@ -0,0 +1,174 @@ +from __future__ import annotations +import unittest +from os import chdir +from pathlib import Path +from typing import Annotated + +from src.pyrestresource import ( + RestField, + register_rest_rootpoint, + RestResourceBase, + rsrc_verb, + RestRequestParams_GET, + RestRequestParams_POST, + RestRequestParams_Dict_GET, + RestRequestParams_PUT, + T_SupportedRESTFields, + ResourcePlugin_field_default, + ResourcePlugin_RestResourceBase_default, + RestResourcePluginException_InvalidPluginSignature, +) + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +# to allow mock-ing, all the tested classes are in a function +def init_classes(): + class ResourcePlugin_version_get(ResourcePlugin_field_default): + def handle_field_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get: + return "1.5.6" + + class ResourcePlugin_version_put(ResourcePlugin_field_default): + def handle_field_put(self, resource: Info_put, params: RestRequestParams_PUT) -> Info_put: + return "42" + + class ResourcePlugin_Info(ResourcePlugin_RestResourceBase_default): + def handle_resource_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get: + return Info_get(version="65.45", api_version="98.321") + + class Info_get(RestResourceBase): + # test plugin injection within annotation + # + test plugin on a simple field + version: Annotated[str, RestField(plugin=ResourcePlugin_version_get)] + api_version: str + + class Info_put(RestResourceBase): + # test plugin injection within annotation + # + test plugin on a simple field + version: Annotated[str, RestField(plugin=ResourcePlugin_version_put)] + api_version: str + + @register_rest_rootpoint + class RootApp(RestResourceBase): + # test plugin injection within Field value + # + test plugin on a RestResourceBase field + info: Info_get = RestField( + default=Info_get(version="0.0.1", api_version="0.0.2"), + plugin=ResourcePlugin_Info, + ) + info_put: Info_put = RestField( + default=Info_put(version="0.0.1", api_version="0.0.2"), + ) + info2: Info_get = RestField(default=Info_get(version="0.0.2", api_version="0.0.3")) + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Info_get.__name__] = Info_get + globals()[Info_put.__name__] = Info_put + globals()[RootApp.__name__] = RootApp + + +def init_bad_plugin1(): + # plugin not inheriting from the right base type + class ResourcePlugin_TestResource: + ... + + class TestResource(RestResourceBase): + tetvaluestr: Annotated[str, RestField(plugin=ResourcePlugin_TestResource)] + + @register_rest_rootpoint + class RootApp2(RestResourceBase): + test: TestResource = RestField(default=TestResource(tetvaluestr="testvalue")) + + RootApp2() + + +class Test_RestAPI_Plugin_PUT(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_put_field_version_fieldplugin(self): + self.testapp.process_request("/info_put/version", rsrc_verb.PUT, '"1.5.6"') + + result = self.testapp.process_request("/info_put", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "42", "api_version": "0.0.2"}') + + result = self.testapp.process_request("/info_put/version", rsrc_verb.GET) + + self.assertEqual(result.get_result(), '"42"') + + def test_put_field_version_resourceplugin(self): + self.testapp.process_request("/info_put", rsrc_verb.PUT, '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("/info_put", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "42", "api_version": "98.321"}') + + +class Test_RestAPI_Plugin_GET(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + self.testapp = RootApp() + + def test_get_root(self): + result = self.testapp.process_request("/", rsrc_verb.GET) + self.assertEqual(result.get_result(), "{}") + + def test_get_version(self): + result = self.testapp.process_request("/info", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("/info2", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') + + def test_get_version__trailing_slash(self): + result = self.testapp.process_request("/info/", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("/info//", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("/info///", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("/info2/", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') + + result = self.testapp.process_request("/info2//", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') + + result = self.testapp.process_request("/info2///", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') + + def test_get_version__multiple_slash(self): + result = self.testapp.process_request("//info", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("///info", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "98.321"}') + + result = self.testapp.process_request("//info2", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') + + result = self.testapp.process_request("///info2", rsrc_verb.GET) + self.assertEqual(result.get_result(), '{"version": "1.5.6", "api_version": "0.0.3"}') + + def test_get_version__nested_value(self): + result = self.testapp.process_request("/info/api_version", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"98.321"') + + result = self.testapp.process_request("/info/version", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"1.5.6"') + + result = self.testapp.process_request("/info2/api_version", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"0.0.3"') + + result = self.testapp.process_request("/info2/version", rsrc_verb.GET) + self.assertEqual(result.get_result(), '"1.5.6"') + + def test_defect_plugin_field(self): + with self.assertRaises(RestResourcePluginException_InvalidPluginSignature): + init_bad_plugin1() diff --git a/test/test_rest_resource_walker.py b/test/test_rest_resource_walker.py new file mode 100644 index 0000000..5fbc619 --- /dev/null +++ b/test/test_rest_resource_walker.py @@ -0,0 +1,155 @@ +from __future__ import annotations +import unittest + +from typing import Optional + +from os import chdir +from pathlib import Path +from io import StringIO +from contextlib import redirect_stdout + +print(__name__) +print(__package__) + +from src.pyrestresource import ( + RestField, + RestResourceBase, +) + +from src.pyrestresource.rest_resource_walker import ( + RestResourceWalker_Root, + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, +) + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +class RestResourceWalker_Sub_T_Dict_TEST_Print(RestResourceWalker_Sub_T_Dict): + cls_counter: dict[str, int] = {} + + def process(self) -> None: + counter = self.cls_counter + if self.resource_name not in counter: + counter[self.resource_name] = 0 + counter[self.resource_name] = counter[self.resource_name] + 1 + + print(f"DICT {self.resource_name} {counter[self.resource_name]}") + + +class RestResourceWalker_Sub_RestFields_TEST_Print(RestResourceWalker_Sub_RestFields): + cls_counter: dict[str, int] = {} + + def process(self) -> None: + counter = self.cls_counter + if self.resource_name not in counter: + counter[self.resource_name] = 0 + counter[self.resource_name] = counter[self.resource_name] + 1 + + print(f"FIELD {self.resource_name} {counter[self.resource_name]}") + + +class RestResourceWalker_Sub_RestResourceBase_TEST_Print(RestResourceWalker_Sub_RestResourceBase): + cls_counter: dict[str, int] = {} + + def process(self) -> None: + counter = self.cls_counter + if self.resource_name not in counter: + counter[self.resource_name] = 0 + counter[self.resource_name] = counter[self.resource_name] + 1 + + print(f"RestResource {self.resource_name} {counter[self.resource_name]}") + + +class RestResourceWalker_Root_TEST_Print(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict_TEST_Print, + RestResourceWalker_Sub_RestFields_TEST_Print, + RestResourceWalker_Sub_RestResourceBase_TEST_Print, + ] + + +def init_classes(): + class Info(RestResourceBase): + version: str + api_version: str + + class People(RestResourceBase): + last_name: str + + class RootApp(RestResourceBase): + info: Info = RestField(default=Info(version="0.0.1", api_version="0.0.2")) + info2: Info = Info(version="0.0.2", api_version="0.0.3") + peoples: dict[str, People] = { + "john": People(last_name="Doe"), + "jane": People(last_name="Roe"), + } + test_string: str = "test value" + test_string_opt: Optional[str] = None + test_int: int = 42 + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Info.__name__] = Info + globals()[People.__name__] = People + globals()[RootApp.__name__] = RootApp + + +class Test_Walker(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + + def test_walk_class(self): + RestResourceWalker_Sub_T_Dict_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestFields_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestResourceBase_TEST_Print.cls_counter = {} + test = RestResourceWalker_Root_TEST_Print(RootApp) + with redirect_stdout(StringIO()) as capted_stdout: + test.process({}) + self.assertIn("RestResource info 1", capted_stdout.getvalue()) + self.assertIn("RestResource info2 1", capted_stdout.getvalue()) + self.assertIn("DICT peoples 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_string 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_string_opt 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_int 1", capted_stdout.getvalue()) + self.assertIn("FIELD version 1", capted_stdout.getvalue()) + self.assertIn("FIELD version 2", capted_stdout.getvalue()) + self.assertIn("FIELD api_version 1", capted_stdout.getvalue()) + self.assertIn("FIELD api_version 2", capted_stdout.getvalue()) + self.assertIn("RestResource peoples 1", capted_stdout.getvalue()) + self.assertIn("FIELD last_name 1", capted_stdout.getvalue()) + + def test_walk_obj(self): + RestResourceWalker_Sub_T_Dict_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestFields_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestResourceBase_TEST_Print.cls_counter = {} + instRootApp = RootApp() + test = RestResourceWalker_Root_TEST_Print(instRootApp) + with redirect_stdout(StringIO()) as capted_stdout: + test.process({}) + self.assertIn("RestResource info 1", capted_stdout.getvalue()) + self.assertIn("RestResource info2 1", capted_stdout.getvalue()) + self.assertIn("DICT peoples 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_string 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_string_opt 1", capted_stdout.getvalue()) + self.assertIn("FIELD test_int 1", capted_stdout.getvalue()) + self.assertIn("FIELD version 1", capted_stdout.getvalue()) + self.assertIn("FIELD version 2", capted_stdout.getvalue()) + self.assertIn("FIELD api_version 1", capted_stdout.getvalue()) + self.assertIn("FIELD api_version 2", capted_stdout.getvalue()) + self.assertIn("RestResource peoples 1", capted_stdout.getvalue()) + self.assertIn("FIELD last_name 1", capted_stdout.getvalue()) + + def test_walk_obj_nested_RestResource(self): + RestResourceWalker_Sub_T_Dict_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestFields_TEST_Print.cls_counter = {} + RestResourceWalker_Sub_RestResourceBase_TEST_Print.cls_counter = {} + instRootApp = RootApp() + test = RestResourceWalker_Root_TEST_Print(instRootApp.info) + with redirect_stdout(StringIO()) as capted_stdout: + test.process({}) + self.assertIn("FIELD version 1", capted_stdout.getvalue()) + self.assertIn("FIELD api_version 1", capted_stdout.getvalue()) diff --git a/test/test_rest_resource_walker_tree.py b/test/test_rest_resource_walker_tree.py new file mode 100644 index 0000000..fb2d400 --- /dev/null +++ b/test/test_rest_resource_walker_tree.py @@ -0,0 +1,149 @@ +from __future__ import annotations +import unittest + +from typing import Optional + +from os import chdir +from pathlib import Path + + +print(__name__) +print(__package__) + +from src.pyrestresource import ( + RestField, + RestResourceBase, +) + +from src.pyrestresource.rest_resource_walker import ( + RestResourceWalkerFutureResult, + RestResourceWalker_Root, + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, +) + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +class RestResourceWalkerFutureResult_RestFields_Test(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + res = dict() + res[self.source.resource_name] = False + return res + + +class RestResourceWalker_Sub_RestFields_TEST_Print(RestResourceWalker_Sub_RestFields): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestFields_Test + + +class RestResourceWalkerFutureResult_RestResourceBase_Test(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + res = dict() + res[self.source.resource_name] = dict() + for subres in result: + res[self.source.resource_name] = res[self.source.resource_name] | subres + return res + + +class RestResourceWalker_Sub_RestResourceBase_TEST_Print(RestResourceWalker_Sub_RestResourceBase): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestResourceBase_Test + + +class RestResourceWalkerFutureResult_Dict_Test(RestResourceWalkerFutureResult[dict]): + def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: + res = dict() + for subres in result: + res = res | subres + return res + + +class RestResourceWalker_Sub_T_Dict_TEST_Print(RestResourceWalker_Sub_T_Dict): + cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_Dict_Test + + +class RestResourceWalker_Root_TEST_Print(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict_TEST_Print, + RestResourceWalker_Sub_RestFields_TEST_Print, + RestResourceWalker_Sub_RestResourceBase_TEST_Print, + ] + + +def init_classes(): + class Info(RestResourceBase): + version: str + api_version: str + + class People(RestResourceBase): + last_name: str + + class RootApp(RestResourceBase): + info: Info = RestField(default=Info(version="0.0.1", api_version="0.0.2")) + info2: Info = Info(version="0.0.2", api_version="0.0.3") + peoples: dict[str, People] = { + "john": People(last_name="Doe"), + "jane": People(last_name="Roe"), + } + test_string: str = "test value" + test_string_opt: Optional[str] = None + test_int: int = 42 + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Info.__name__] = Info + globals()[People.__name__] = People + globals()[RootApp.__name__] = RootApp + + +class Test_Walker_tree(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + init_classes() + + def test_walk_class(self): + test = RestResourceWalker_Root_TEST_Print(RootApp) + res = test.process() + self.assertDictEqual( + res, + { + "/": { + "info": {"version": False, "api_version": False}, + "info2": {"version": False, "api_version": False}, + "peoples": {"last_name": False}, + "test_string": False, + "test_string_opt": False, + "test_int": False, + } + }, + ) + + def test_walk_obj(self): + instRootApp = RootApp() + test = RestResourceWalker_Root_TEST_Print(instRootApp) + res = test.process() + self.assertDictEqual( + res, + { + "/": { + "info": {"version": False, "api_version": False}, + "info2": {"version": False, "api_version": False}, + "peoples": {"last_name": False}, + "test_string": False, + "test_string_opt": False, + "test_int": False, + } + }, + ) + + def test_walk_obj_nested_RestResource(self): + instRootApp = RootApp() + test = RestResourceWalker_Root_TEST_Print(instRootApp.info) + res = test.process() + self.assertDictEqual( + res, + { + "/": {"version": False, "api_version": False}, + }, + ) diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py new file mode 100644 index 0000000..dbe2af7 --- /dev/null +++ b/test/test_rest_webserver.py @@ -0,0 +1,360 @@ +from __future__ import annotations +import unittest +from unittest.mock import patch +from os import chdir +from pathlib import Path +from typing import Optional +from uuid import UUID, uuid4 +from time import time, sleep +import json +import uvicorn +import socket +import requests +from contextlib import closing +from multiprocessing import Process +from requests.adapters import HTTPAdapter +import coverage + +print(__name__) +print(__package__) + +from src.pyrestresource import ( + RestField, + register_rest_rootpoint, + RestResourceBase, + rsrc_verb, + RestRequestParams_GET, + RestRequestParams_POST, + RestRequestParams_Dict_GET, + T_SupportedRESTFields, +) +from pprint import pprint + +from test import ThreadedUvicorn + +testdir_path = Path(__file__).parent.resolve() +chdir(testdir_path.parent.resolve()) + + +# to allow mock-ing, all the tested classes are in a function +def init_classes(): + class Info(RestResourceBase): + version: str + api_version: str + + class Patch(RestResourceBase): + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + class Profile(RestResourceBase): + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + class Game(RestResourceBase): + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + profiles: dict[UUID, Profile] = {} + patchs: dict[UUID, Patch] = {} + + Patch_1 = Patch(uuid="cee1e870-65fa-11ee-8c99-0242ac120002", shortname="testPatch1") + Patch_2 = Patch(uuid="d385a1d2-65fa-11ee-8c99-0242ac120002", shortname="testPatch2") + + class User(RestResourceBase): + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) + name: str + secret: str = RestField(..., exclude=True) + + User1 = User( + uuid="8da57a3c-661f-11ee-8c99-0242ac120002", + name="chacha", + secret="la blanquette est bonne", + ) + + class Patch2(RestResourceBase): + uuid: UUID = RestField(default_factory=uuid4, primary_key=True) + shortname: str + name: Optional[str] = None + description: Optional[str] = None + + @register_rest_rootpoint + class RootApp(RestResourceBase): + testValueRoot: float = 3.14 + info: Info = Info(version="0.0.1", api_version="0.0.2") + games: dict[UUID, Game] = { + UUID("9b0381d4-65f6-11ee-8c99-0242ac120002"): Game( + uuid="9b0381d4-65f6-11ee-8c99-0242ac120002", + shortname="testGame Origin", + description="test Game Desc Origin", + patchs={Patch_1.uuid: Patch_1}, + profiles={ + UUID("aee1e870-65fa-11ee-8c99-0242ac120002"): Profile( + uuid="aee1e870-65fa-11ee-8c99-0242ac120002", + shortname="testprofile", + ) + }, + ) + } + patchs: dict[UUID, Patch] = {Patch_1.uuid: Patch_1, Patch_2.uuid: Patch_2} + users: dict[UUID, User] = {User1.uuid: User1} + + patchs2: dict[UUID, Patch2] = {} + + # this add the classes to globals to allow using them later on + # => this is only for uinit-testing purpose and is not needed in real use + globals()[Info.__name__] = Info + globals()[Game.__name__] = Game + globals()[User.__name__] = User + globals()[Profile.__name__] = Profile + globals()[Patch.__name__] = Patch + globals()[Patch2.__name__] = Patch2 + globals()[RootApp.__name__] = RootApp + + +def find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return "localhost", s.getsockname()[1] + + +class Test_RestAPI_WebServer(unittest.TestCase): + def setUp(self) -> None: + chdir(testdir_path.parent.resolve()) + + def test_nomal_AllCmd_games(self): + ip, port = find_free_port() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # Fetching games + response = s.get(f"http://{ip}:{port}/games") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIsInstance(data, list) + self.assertEqual( + data, + ["9b0381d4-65f6-11ee-8c99-0242ac120002"], + ) + + # Add a new one (with all values setted) + response = s.post( + f"http://{ip}:{port}/games", + json={ + "shortname": "test", + "name": "nametest", + "description": "test Game Desc", + }, + ) + self.assertEqual(response.status_code, 201) + data = response.json() + NEW_GAME_UUID = UUID(data) + + # Fetching games again + response = s.get(f"http://{ip}:{port}/games") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIsInstance(data, list) + self.assertEqual( + data, + ["9b0381d4-65f6-11ee-8c99-0242ac120002", str(NEW_GAME_UUID)], + ) + + # Getting accurate values of created element + response = s.get(f"http://{ip}:{port}/games/{str(NEW_GAME_UUID)}") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIsInstance(data, dict) + self.assertIsInstance(data, dict) + self.assertIn("shortname", data) + self.assertIn("name", data) + self.assertIn("description", data) + self.assertIn("uuid", data) + NEW_GAME_UUID = UUID(data["uuid"]) + del data["uuid"] + self.assertDictEqual( + data, + { + "name": "nametest", + "shortname": "test", + "description": "test Game Desc", + }, + ) + + # removing the new one + response = s.delete(f"http://{ip}:{port}/games/{str(NEW_GAME_UUID)}") + self.assertEqual(response.status_code, 200) + + # Fetching games again + response = s.get(f"http://{ip}:{port}/games") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual( + data, + ["9b0381d4-65f6-11ee-8c99-0242ac120002"], + ) + + # Getting accurate values + response = s.get(f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIsInstance(data, dict) + self.assertIsInstance(data, dict) + self.assertIn("shortname", data) + self.assertIn("name", data) + self.assertIn("description", data) + self.assertIn("uuid", data) + NEW_GAME_UUID = UUID(data["uuid"]) + del data["uuid"] + self.assertDictEqual( + data, + { + "name": None, + "shortname": "testGame Origin", + "description": "test Game Desc Origin", + }, + ) + + # Update values + response = s.put( + f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002", + json={ + "name": "MyName", + }, + ) + self.assertEqual(response.status_code, 201) + + # Getting accurate values + response = s.get(f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIsInstance(data, dict) + self.assertIsInstance(data, dict) + self.assertIn("shortname", data) + self.assertIn("name", data) + self.assertIn("description", data) + self.assertIn("uuid", data) + NEW_GAME_UUID = UUID(data["uuid"]) + del data["uuid"] + self.assertDictEqual( + data, + { + "name": "MyName", + "shortname": "testGame Origin", + "description": "test Game Desc Origin", + }, + ) + + # removing original element + response = s.delete(f"http://{ip}:{port}/games?API_key={str(NEW_GAME_UUID)}") + self.assertEqual(response.status_code, 200) + + # Fetching games again + response = s.get(f"http://{ip}:{port}/games") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertTrue(len(data) == 0) + finally: + s.close() + server.stop() + + @unittest.skip + def test_perf_dict(self): + print(f"SOCKET PERF TEST") + n_loop = 10000 + + ip, port = find_free_port() + init_classes() + + server = ThreadedUvicorn(uvicorn.Config(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True)) + server.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + start = time() + for _ in range(n_loop): + s.get(f"http://{ip}:{port}/users/8da57a3c-661f-11ee-8c99-0242ac120002") + end = time() + print(f"GET 1st level dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + newUUID = uuid4() + s.post( + f"http://{ip}:{port}/users?API_key={newUUID}", + json={"name": "testUser", "secret": "test"}, + ) + end = time() + + print(f"POST 1st level dict (API_key): {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + newUUID = uuid4() + s.post( + f"http://{ip}:{port}/users?API_key={str(newUUID)}", + json={"name": "testUser", "secret": "test"}, + ) + s.get(f"http://{ip}:{port}/users/{newUUID}") + end = time() + print(f"POST/GET 1st level dict (API_key): {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + response = s.post(f"http://{ip}:{port}/users", '{"name": "testUser", "secret": "test"}') + s.get(f"http://{ip}:{port}/users/{response.json()}") + end = time() + print(f"POST/GET 1st level dict (autokey): {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + s.put( + f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname", + json="TestValue!!", + ) + s.get(f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/shortname") + end = time() + print(f"PUT/GET 1st level (value) dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + s.get(f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002") + end = time() + print(f"GET 2nd level dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + s.get( + f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + ) + end = time() + print(f"GET 2nd level (value) dict: {int(n_loop/(end-start))} Req/s") + + start = time() + for _ in range(n_loop): + s.put( + f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + json="TestValue!!", + ) + s.get( + f"http://{ip}:{port}/games/9b0381d4-65f6-11ee-8c99-0242ac120002/patchs/cee1e870-65fa-11ee-8c99-0242ac120002/shortname", + ) + end = time() + print(f"PUT/GET 2nd level (value) dict: {int(n_loop/(end-start))} Req/s") + + finally: + s.close() + server.stop() diff --git a/test/test_test_module.py b/test/test_test_module.py deleted file mode 100644 index 9f4280d..0000000 --- a/test/test_test_module.py +++ /dev/null @@ -1,35 +0,0 @@ -# pyrestresource (c) by chacha -# -# pyrestresource is licensed under a -# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. -# -# You should have received a copy of the license along with this -# work. If not, see . - -import unittest -from os import chdir -from io import StringIO -from contextlib import redirect_stdout,redirect_stderr -from pathlib import Path - -print(__name__) -print(__package__) - -from src import pyrestresource - -testdir_path = Path(__file__).parent.resolve() -chdir(testdir_path.parent.resolve()) - -class Testtest_module(unittest.TestCase): - def setUp(self) -> None: - chdir(testdir_path.parent.resolve()) - - def test_version(self): - self.assertNotEqual(pyrestresource.__version__,"?.?.?") - - def test_test_module(self): - - with redirect_stdout(StringIO()) as capted_stdout, redirect_stderr(StringIO()) as capted_stderr: - self.assertEqual(pyrestresource.test_function(41),42) - self.assertEqual(len(capted_stderr.getvalue()),0) - self.assertEqual(capted_stdout.getvalue().strip(),"Hello world !") \ No newline at end of file