diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index 3722afd..d129cc3 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -19,10 +19,8 @@ from typing import TYPE_CHECKING from .__metadata__ import __version__, __Summuary__, __Name__ -from .rest_resource import ( - register_rest_rootpoint, - RestResourceBase, -) +from .rest_resource import RestResourceBase +from .rest_resource_rootpoint import register_rest_rootpoint from .rest_types import rsrc_verb, T_SupportedRESTFields @@ -34,6 +32,7 @@ if TYPE_CHECKING: T_T_DictKey, T_DictValues, T_T_DictValues, + RestResourceException, ) from .rest_request_opt import ( @@ -52,6 +51,8 @@ from .rest_resource_plugin import ( ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, ResourcePlugin_dict_default, + RestResourcePluginException, + RestResourcePluginException_InvalidPluginSignature, ) from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule from .rest_login import RestResourceBaseLogin, UserLogin diff --git a/src/pyrestresource/helpers.py b/src/pyrestresource/helpers.py index a08067b..e8c13d1 100644 --- a/src/pyrestresource/helpers.py +++ b/src/pyrestresource/helpers.py @@ -1,6 +1,7 @@ # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring from __future__ import annotations + from uuid import UUID import json diff --git a/src/pyrestresource/rest_ACL.py b/src/pyrestresource/rest_ACL.py index 3ebdbaa..e7158e4 100644 --- a/src/pyrestresource/rest_ACL.py +++ b/src/pyrestresource/rest_ACL.py @@ -1,10 +1,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING from pydantic import BaseModel from enum import Enum, auto from .rest_types import rsrc_verb +if TYPE_CHECKING is True: + from .rest_login import UserLogin + class ACL_target(BaseModel): pass @@ -13,6 +17,10 @@ class ACL_target(BaseModel): class ACL_target_user(ACL_target): name: str + @classmethod + def from_user_login(cls, user_login: UserLogin) -> ACL_target_user: + return cls(name=user_login.username) + class ACL_target_user_Annonymous(ACL_target): name: str = "__ANNONYMOUS__" diff --git a/src/pyrestresource/rest_login.py b/src/pyrestresource/rest_login.py index 0fa0c2e..707e421 100644 --- a/src/pyrestresource/rest_login.py +++ b/src/pyrestresource/rest_login.py @@ -12,21 +12,19 @@ """CLI interface module""" from __future__ import annotations - from typing import Optional, ClassVar, TYPE_CHECKING + from secrets import token_hex, compare_digest from datetime import datetime - from pydantic import BaseModel, Field from .rest_types import rsrc_verb from .rest_resource import RestResourceBase +from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule, ACL_target_user +from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default -from .rest_request import RestRequest, RestRequestParams_GET -from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule - -if TYPE_CHECKING or True: - from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default +if TYPE_CHECKING is True: + from .rest_request import RestRequest, RestRequestParams_GET class UserLogin(BaseModel): @@ -44,20 +42,11 @@ class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): ar_UserLogin: list[UserLogin] = [] def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login: - print("hook GET") - print(resource) - print(params) - return resource + return Login(username=self.get_user_login()) def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login: - print("hook PUT") - # print(self.get_ar_userlogin()) - print(resource.username) - print(resource.secret) - token = self.user_login(resource.username, resource.secret) self.set_resp_cookie_value("Authorization", f"Bearer {token}") - return resource @@ -93,7 +82,7 @@ class RestResourceBaseLogin(RestResourceBase): del self._ar_user_session[auth_cookie] raise RuntimeError("session timeout ! (session reseted)") - request.set_user(self._ar_user_session[auth_cookie].user_login.username) + request.set_user(ACL_target_user(name=self._ar_user_session[auth_cookie].user_login.username)) return print("Invalid session") diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index c5d08cd..0d37b7d 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -4,19 +4,16 @@ from __future__ import annotations from typing import ( Any, - Optional, Generic, + TYPE_CHECKING, ) + from re import sub from urllib.parse import urlparse, parse_qs -from http.cookies import SimpleCookie - from pydantic import BaseModel, Field - from typeguard import check_type -from .rest_types import rsrc_verb, T_SupportedRESTFields, T_AllSupportedFields - +from .rest_types import rsrc_verb, T_AllSupportedFields from .rest_request_opt import ( RestRequestParams_POST, RestRequestParams_DELETE, @@ -28,10 +25,13 @@ from .rest_request_opt import ( _T_RestRequestParams_GET, _T_RestRequestParams_PUT, ) - from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group from .helpers import parse_dict_cookies +if TYPE_CHECKING is True: + from typing import Optional + from .rest_types import T_SupportedRESTFields + class RequestFactory( Generic[ @@ -182,7 +182,10 @@ class RestRequest(Generic[_T_RestRequestParams]): return self.result def set_user(self, user: ACL_target_user): - self.user: ACL_target_user = user + self.user = user + + def get_user(self): + return self.user def add_group(self, group: ACL_target_group): self.groups.append(group) diff --git a/src/pyrestresource/rest_request_opt.py b/src/pyrestresource/rest_request_opt.py index 10cc360..599bdbe 100644 --- a/src/pyrestresource/rest_request_opt.py +++ b/src/pyrestresource/rest_request_opt.py @@ -1,14 +1,17 @@ # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring from __future__ import annotations +from typing import Generic, Optional, TypeVar, TYPE_CHECKING -from typing import Optional, Generic, TypeVar from pydantic import BaseModel, Extra from .rest_types import ( _T_DictKey, ) +if TYPE_CHECKING is True: + pass + class RestRequestParams(BaseModel, extra=Extra.allow): pass diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index 75525f5..bd42c32 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -1,250 +1,35 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# pyrestresource(c) by chacha -# -# pyrestresource is licensed under a -# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. -# -# You should have received a copy of the license along with this -# work. If not, see . - -# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring - -"""CLI interface module""" from __future__ import annotations - -from abc import ABC from typing import ( Any, - cast, ClassVar, - get_args, - get_origin, Optional, TYPE_CHECKING, ) + +from abc import ABC import json -from pydantic.fields import FieldInfo, Field from pydantic import BaseModel from .helpers import _JSONEncoder from .rest_types import rsrc_verb, _T_SupportedRESTFields -from .rest_resource_plugin import ( - ResourcePlugin_field, - ResourcePlugin_RestResourceBase, - ResourcePlugin_dict, -) from .rest_ACL import ( ACL_record, ACL_target_user, ACL_target_group, - ACL_target_user_Annonymous, ACL_target_group_Any, ACL_rule, ) - -from .rest_resource_walker import ( - RestResourceWalkerFutureResult, - RestResourceWalker_Root, - RestResourceWalker_Sub_T_Dict, - RestResourceWalker_Sub_RestFields, - RestResourceWalker_Sub_RestResourceBase, -) - from .rest_request import RestRequest -if TYPE_CHECKING: +if TYPE_CHECKING is True: from .rest_types import ( - T_ListIndex, - T_ListSize, - T_DictKey, T_T_DictKey, - T_DictValues, T_T_DictValues, - T_SupportedRESTFields, ) -class RestResourceWalkerFutureResult_RestResourceBase_tree_exclude(RestResourceWalkerFutureResult[dict]): - def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: - res = {} - res[self.source.resource_name] = dict() - for subres in result: - key = next(iter(subres)) - if ( - key in self.source.annotation._model_dump_excluded_ # pylint: disable=protected-access - and self.source.annotation._model_dump_excluded_[key] is True # pylint: disable=protected-access - ): - res[self.source.resource_name] = res[self.source.resource_name] | {key: True} - else: - res[self.source.resource_name] = res[self.source.resource_name] | subres - return res - - -class RestResourceWalkerFutureResult_Dict_tree_exclude(RestResourceWalkerFutureResult[dict]): - def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: - res = {} - for subres in result: - res = res | subres - return res - - -class RestResourceWalker_Sub_T_Dict__tree_exclude(RestResourceWalker_Sub_T_Dict): - cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_Dict_tree_exclude - - -class RestResourceWalker_Sub_RestResourceBase__tree_exclude(RestResourceWalker_Sub_RestResourceBase): - cls_RestResourceWalkerFutureResult = RestResourceWalkerFutureResult_RestResourceBase_tree_exclude - - -class RestResourceWalker_Root__tree_exclude(RestResourceWalker_Root): - cls_RestResourceWalker_Sub = [ - RestResourceWalker_Sub_T_Dict__tree_exclude, - RestResourceWalker_Sub_RestFields, - RestResourceWalker_Sub_RestResourceBase__tree_exclude, - ] - - -class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): - def process(self) -> None: - datatype = get_args(self.annotation) - - # checking compatibility - if not get_origin(datatype[1]) is None: - raise RuntimeError("complex dict types are not supported (should create a RestResourceBase container)") - if not datatype[0] in _T_SupportedRESTFields: - raise RuntimeError(f"Unsupported Dict Field value type in class (key)") - - # preprocessing types / structure - if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): - self.parent.annotation._dict_key_type_[self.resource_name] = datatype[0] # pylint: disable=protected-access - self.parent.annotation._dict_value_type_[self.resource_name] = datatype[1] # pylint: disable=protected-access - self.parent.annotation._model_dump_excluded_[self.resource_name] = True # pylint: disable=protected-access - - self.resource.exclude = True - self.parent.resource.model_rebuild(force=True) - - self.parent.annotation._ACL_record_[self.resource_name] = [] - - if ( - isinstance(self.resource, FieldInfo) - and self.resource.json_schema_extra is not None - and type(self.resource.json_schema_extra) is dict - ): - if "plugin" in self.resource.json_schema_extra: - plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"] - if not isinstance(plugin_dict, ResourcePlugin_dict): - raise RuntimeError("Wrong plugin signature provided") - self.parent.annotation._plugins_[self.resource_name] = plugin_dict - # print("ADD DICT PLUGIN") - - if "ACL" in self.resource.json_schema_extra: - if isinstance(self.resource.json_schema_extra["ACL"], list): - # print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") - self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] - else: - raise RuntimeError("ACL must be a list()") - - else: - raise RuntimeError("dict must be contained in a RestResourceBase") - - -class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFields): - def process(self) -> None: - if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): - import pprint - - # print("1aaaaaaaaaa") - # pprint.pprint(self.resource.json_schema_extra) - # pprint.pprint(self.annotation) - # pprint.pprint(self.resource.exclude) - - self.parent.annotation._ACL_record_[self.resource_name] = [] - - if ( - isinstance(self.resource, FieldInfo) - and self.resource.json_schema_extra is not None - and type(self.resource.json_schema_extra) is dict - ): - # print("aaaaaaaaaa") - - if "primary_key" in self.resource.json_schema_extra and self.resource.json_schema_extra["primary_key"] is True: - if self.parent.annotation._primary_key_ is not None: - raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") - self.parent.annotation._primary_key_ = self.resource_name - self.parent.annotation._ACL_record_[self.resource_name] = [ - ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY) - ] - - if "plugin" in self.resource.json_schema_extra: - plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"] - if not isinstance(plugin_field, ResourcePlugin_field): - raise RuntimeError("Wrong plugin signature provided") - self.parent.annotation._plugins_[self.resource_name] = plugin_field - # print("ADD FIELD PLUGIN") - - if "ACL" in self.resource.json_schema_extra: - if isinstance(self.resource.json_schema_extra["ACL"], list): - # print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") - self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] - else: - raise RuntimeError("ACL must be a list()") - - else: - raise RuntimeError("fields must be contained in a RestResourceBase") - - -class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_RestResourceBase): - def process(self) -> None: - setattr(self.annotation, "_dict_key_type_", {}) - setattr(self.annotation, "_dict_value_type_", {}) - setattr(self.annotation, "_model_dump_excluded_", {}) - setattr(self.annotation, "_primary_key_", None) - setattr(self.annotation, "_plugins_", {}) - setattr(self.annotation, "_ACL_record_", {}) - - # preprocessing types / structure - if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): - self.parent.annotation._model_dump_excluded_[self.resource_name] = True - self.resource.exclude = True - self.parent.resource.model_rebuild(force=True) - self.parent.annotation._ACL_record_[self.resource_name] = [] - - if ( - isinstance(self.resource, FieldInfo) - and self.resource.json_schema_extra is not None - and type(self.resource.json_schema_extra) is dict - ): - if "plugin" in self.resource.json_schema_extra: - plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] - if not issubclass(plugin_resource, ResourcePlugin_RestResourceBase): - raise RuntimeError(f"Wrong plugin signature provided for {plugin_resource} : {type(plugin_resource)}") - self.parent.annotation._plugins_[self.resource_name] = plugin_resource - # print("ADD RESOURCE PLUGIN") - - if "ACL" in self.resource.json_schema_extra: - if isinstance(self.resource.json_schema_extra["ACL"], list): - # print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") - self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] - else: - raise RuntimeError("ACL must be a list()") - - -class RestResourceWalker_Root__tree_init(RestResourceWalker_Root): - cls_RestResourceWalker_Sub = [ - RestResourceWalker_Sub_T_Dict__tree_init, - RestResourceWalker_Sub_RestFields__tree_init, - RestResourceWalker_Sub_RestResourceBase__tree_init, - ] - - -def register_rest_rootpoint(klass: type[RestResourceBase]): - RestResourceWalker_Root__tree_init(klass).process() - return klass - - class RestResourceBase(ABC, BaseModel, validate_assignment=True): # _resp_cookies: ClassVar[dict[str, str]] = {} _dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {} @@ -264,31 +49,31 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): ] ] = {} - def _check_acl(self, user: str, groups: list[ACL_target_group], verb: rsrc_verb, field: str, is_self: bool = True): - print(f"evaluate self ACLs rule: {self._ACL_record_}") - print(f"user: {user}") - print(f"groups: {groups}") + def _check_acl(self, user: ACL_target_user, groups: list[ACL_target_group], verb: rsrc_verb, field: str, is_self: bool = True): + # print(f"evaluate self ACLs rule: {self._ACL_record_}") + # print(f"user: {user}") + # print(f"groups: {groups}") if is_self and verb is rsrc_verb.GET and self.model_fields[field].exclude is True: # print("ALLOWED (excluded field)") return for acl in self._ACL_record_[field]: - print(f"evaluate ACL rule: {acl}") + # print(f"evaluate ACL rule: {acl}") if verb in acl.verbs: if isinstance(acl.target, ACL_target_user): - if user == acl.target.name: + if user == acl.target: if acl.rule is ACL_rule.ALLOW: - print("ALLOWED (user)") + # print("ALLOWED (user)") return raise RuntimeError(f"Not allowed access detected: {field}") elif isinstance(acl.target, ACL_target_group): - if acl.target.name in groups or isinstance(acl.target, ACL_target_group_Any): + if isinstance(acl.target, ACL_target_group_Any) or any(_ for _ in groups if _.name == acl.target.name): if acl.rule is ACL_rule.ALLOW: - print("ALLOWED (group)") + # print("ALLOWED (group)") return raise RuntimeError(f"Not allowed access detected: {field}") else: raise RuntimeError(f"Wrong ACL target type: {field}") - print("ALLOWED (Default)") + # print("ALLOWED (Default)") def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None: """Check ACL on requested field access""" @@ -334,10 +119,10 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if b"content-type" in scope["headers"]: assert scope["headers"][b"content-type"] == b"application/json" - import pprint + # import pprint - print("----REC HEADER ---") - pprint.pprint(scope["headers"]) + # print("----REC HEADER ---") + # pprint.pprint(scope["headers"]) body = await self.read_body(receive) verb = rsrc_verb[scope["method"]] diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 5e99bd7..77cfd42 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -1,32 +1,22 @@ from __future__ import annotations -import abc from typing import Optional, cast, TypeVar, Generic, Self, TYPE_CHECKING +import abc + from .rest_types import ( rsrc_verb, T_SupportedRESTFields, T_DictKey, _T_SupportedRESTFields, T_Dict, - T_T_DictValues, T_DictValues, ) from .rest_resource import RestResourceBase -from .rest_request import RequestFactory, RestRequest - +from .rest_request import RequestFactory from .rest_resource_plugin import ( ResourcePlugin_field, ResourcePlugin_RestResourceBase, ) - -from .rest_ACL import ( - ACL_target_user, - ACL_target_group, - ACL_target_user_Annonymous, - ACL_target_group_Any, - ACL_rule, -) - from .rest_request_opt import ( RestRequestParams_POST, RestRequestParams_DELETE, @@ -43,16 +33,9 @@ from .rest_request_opt import ( _T_RestRequestParams_PUT, ) -from .rest_resource_handler_walker import RestResourceWalker_Root__handler - -if TYPE_CHECKING: - from .rest_types import ( - T_ListIndex, - T_ListSize, - T_T_DictKey, - T_FieldValue, - ) - +if TYPE_CHECKING is True: + from .rest_types import T_T_DictKey, T_T_DictValues + from .rest_request import RestRequest _T_Resource = TypeVar("_T_Resource", T_DictValues, T_Dict, T_SupportedRESTFields, RestResourceBase) @@ -311,7 +294,7 @@ class ResourceHandler_dict( dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] - _dict: dict[T_DictKey, "T_DictValues"] = cast(dict[T_DictKey, "T_DictValues"], self.resource) + _dict: dict[T_DictKey, T_DictValues] = cast(dict[T_DictKey, T_DictValues], self.resource) if params.API_key is not None: del _dict[dict_key_type(params.API_key)] @@ -335,7 +318,7 @@ class ResourceHandler_dict( _obj = dict_value_type(**self.req.get_data()) - _dict: dict[T_DictKey, "T_DictValues"] = cast(dict[T_DictKey, "T_DictValues"], self.resource) + _dict: dict[T_DictKey, T_DictValues] = cast(dict[T_DictKey, T_DictValues], self.resource) # 1st try/ using request param provided dict API_key if params.API_key is not None: @@ -495,13 +478,13 @@ class ResourceHandler_RestResourceBase( self.resource.check_acl_self(self.req, None) for key, attr in self.resource.model_fields.items(): if key in self.resource._plugins_: - if isinstance(self.resource._plugins_[key], ResourcePlugin_field): + if issubclass(self.resource._plugins_[key], ResourcePlugin_field): plugin_field: ResourcePlugin_field = cast( ResourcePlugin_field, self.resource._plugins_[key](self.req, self.root_resource) ) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_field_get(value, params)) - elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): + elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_field: ResourcePlugin_field = cast( ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource) ) @@ -523,14 +506,14 @@ class ResourceHandler_RestResourceBase( key = self.req.get_resource_origin(0) if key in self.resource._plugins_: - if isinstance(self.resource._plugins_[key], ResourcePlugin_field): + if issubclass(self.resource._plugins_[key], ResourcePlugin_field): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource), ) value = plugin_rsrc.handle_field_get(value, params) - elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): + elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource), @@ -554,7 +537,7 @@ class ResourceHandler_RestResourceBase( if isinstance(_new_resrc, RestResourceBase): for key, attr in _new_resrc.model_fields.items(): if key in _new_resrc._plugins_: - if isinstance(_new_resrc._plugins_[key], ResourcePlugin_field): + if issubclass(_new_resrc._plugins_[key], ResourcePlugin_field): plugin_field: ResourcePlugin_field = cast( ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.root_resource) ) diff --git a/src/pyrestresource/rest_resource_handler_walker.py b/src/pyrestresource/rest_resource_handler_walker.py index e23ec6f..c0f06ac 100644 --- a/src/pyrestresource/rest_resource_handler_walker.py +++ b/src/pyrestresource/rest_resource_handler_walker.py @@ -12,14 +12,7 @@ """CLI interface module""" from __future__ import annotations - -from typing import ( - ClassVar, - get_args, - get_origin, - Optional, - TYPE_CHECKING, -) +from typing import TYPE_CHECKING from .rest_resource_walker import ( RestResourceWalkerFutureResult, @@ -29,6 +22,9 @@ from .rest_resource_walker import ( RestResourceWalker_Sub_RestResourceBase, ) +if TYPE_CHECKING is True: + from typing import Optional + class RestResourceWalkerFutureResult_RestResourceBase_handler(RestResourceWalkerFutureResult[dict]): def process_future(self, result: Optional[list[dict]]) -> Optional[dict]: diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py index 72290f7..4e7bcea 100644 --- a/src/pyrestresource/rest_resource_plugin.py +++ b/src/pyrestresource/rest_resource_plugin.py @@ -1,6 +1,6 @@ from __future__ import annotations - from typing import Optional, Generic, TYPE_CHECKING + from abc import abstractmethod, ABC from .rest_types import ( @@ -8,12 +8,12 @@ from .rest_types import ( _T_DictKey, TV_SupportedRESTFields, TV_RestResourceBase, + RestResourceException, ) - from .rest_request import RestRequest - -if TYPE_CHECKING or True: +if TYPE_CHECKING is True: + from .rest_resource import RestResourceBase from .rest_request_opt import ( RestRequestParams_GET, RestRequestParams_PUT, @@ -27,19 +27,24 @@ if TYPE_CHECKING or True: ) +class RestResourcePluginException(RestResourceException): + pass + + +class RestResourcePluginException_InvalidPluginSignature(RestResourcePluginException): + pass + + class ResourcePlugin(ABC): - def __init__(self, request: RestRequest, root_resource: "RestResourceBase") -> None: + def __init__(self, request: RestRequest, root_resource: RestResourceBase) -> None: self.__request: RestRequest = request self.__root_resource: RestRequest = root_resource def user_login(self, user_name: str, user_secret: str) -> str: return self.__root_resource.user_login(user_name, user_secret, self.__request) - """ - def get_ar_userlogin(self): - print("===========") - return self.__root_resource.get_ar_user_login() - """ + def get_user_login(self) -> str: + return self.__request.get_user().name def getr_req_cookie_value(self, key: str) -> Optional[str]: return self.__request.incoming_cookie[key] diff --git a/src/pyrestresource/rest_resource_rootpoint.py b/src/pyrestresource/rest_resource_rootpoint.py new file mode 100644 index 0000000..a13e521 --- /dev/null +++ b/src/pyrestresource/rest_resource_rootpoint.py @@ -0,0 +1,169 @@ +from __future__ import annotations +from typing import ( + get_args, + get_origin, + TYPE_CHECKING, +) + +from pydantic.fields import FieldInfo + +from .rest_resource import RestResourceBase +from .rest_resource_plugin import ( + ResourcePlugin_field, + ResourcePlugin_RestResourceBase, + ResourcePlugin_dict, + RestResourcePluginException_InvalidPluginSignature, +) +from .rest_resource_walker import ( + RestResourceWalker_Root, + RestResourceWalker_Sub_T_Dict, + RestResourceWalker_Sub_RestFields, + RestResourceWalker_Sub_RestResourceBase, +) +from .rest_types import rsrc_verb, _T_SupportedRESTFields +from .rest_ACL import ( + ACL_record, + ACL_target_group_Any, + ACL_rule, +) + +if TYPE_CHECKING is True: + pass + + +class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): + def process(self) -> None: + datatype = get_args(self.annotation) + + # checking compatibility + if not get_origin(datatype[1]) is None: + raise RuntimeError("complex dict types are not supported (should create a RestResourceBase container)") + if not datatype[0] in _T_SupportedRESTFields: + raise RuntimeError(f"Unsupported Dict Field value type in class (key)") + + # preprocessing types / structure + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + self.parent.annotation._dict_key_type_[self.resource_name] = datatype[0] # pylint: disable=protected-access + self.parent.annotation._dict_value_type_[self.resource_name] = datatype[1] # pylint: disable=protected-access + self.parent.annotation._model_dump_excluded_[self.resource_name] = True # pylint: disable=protected-access + + self.resource.exclude = True + self.parent.resource.model_rebuild(force=True) + + self.parent.annotation._ACL_record_[self.resource_name] = [] + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + ): + if "plugin" in self.resource.json_schema_extra: + plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"] + if not issubclass(plugin_dict, ResourcePlugin_dict): + raise RestResourcePluginException_InvalidPluginSignature() + self.parent.annotation._plugins_[self.resource_name] = plugin_dict + # print("ADD DICT PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + # print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") + + else: + raise RuntimeError("dict must be contained in a RestResourceBase") + + +class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFields): + def process(self) -> None: + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + # import pprint + + # print("1aaaaaaaaaa") + # pprint.pprint(self.resource.json_schema_extra) + # pprint.pprint(self.annotation) + # pprint.pprint(self.resource.exclude) + + self.parent.annotation._ACL_record_[self.resource_name] = [] + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + ): + # print("aaaaaaaaaa") + + if "primary_key" in self.resource.json_schema_extra and self.resource.json_schema_extra["primary_key"] is True: + if self.parent.annotation._primary_key_ is not None: + raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") + self.parent.annotation._primary_key_ = self.resource_name + self.parent.annotation._ACL_record_[self.resource_name] = [ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY) + ] + + if "plugin" in self.resource.json_schema_extra: + plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"] + if not issubclass(plugin_field, ResourcePlugin_field): + raise RestResourcePluginException_InvalidPluginSignature() + self.parent.annotation._plugins_[self.resource_name] = plugin_field + # print("ADD FIELD PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + # print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") + + else: + raise RuntimeError("fields must be contained in a RestResourceBase") + + +class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_RestResourceBase): + def process(self) -> None: + setattr(self.annotation, "_dict_key_type_", {}) + setattr(self.annotation, "_dict_value_type_", {}) + setattr(self.annotation, "_model_dump_excluded_", {}) + setattr(self.annotation, "_primary_key_", None) + setattr(self.annotation, "_plugins_", {}) + setattr(self.annotation, "_ACL_record_", {}) + + # preprocessing types / structure + if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): + self.parent.annotation._model_dump_excluded_[self.resource_name] = True + self.resource.exclude = True + self.parent.resource.model_rebuild(force=True) + self.parent.annotation._ACL_record_[self.resource_name] = [] + + if ( + isinstance(self.resource, FieldInfo) + and self.resource.json_schema_extra is not None + and type(self.resource.json_schema_extra) is dict + ): + if "plugin" in self.resource.json_schema_extra: + plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] + if not issubclass(plugin_resource, ResourcePlugin_RestResourceBase): + raise RestResourcePluginException_InvalidPluginSignature() + self.parent.annotation._plugins_[self.resource_name] = plugin_resource + # print("ADD RESOURCE PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + # print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") + + +class RestResourceWalker_Root__tree_init(RestResourceWalker_Root): + cls_RestResourceWalker_Sub = [ + RestResourceWalker_Sub_T_Dict__tree_init, + RestResourceWalker_Sub_RestFields__tree_init, + RestResourceWalker_Sub_RestResourceBase__tree_init, + ] + + +def register_rest_rootpoint(klass: type[RestResourceBase]): + RestResourceWalker_Root__tree_init(klass).process() + return klass diff --git a/src/pyrestresource/rest_resource_walker.py b/src/pyrestresource/rest_resource_walker.py index 808eb5d..a72a834 100644 --- a/src/pyrestresource/rest_resource_walker.py +++ b/src/pyrestresource/rest_resource_walker.py @@ -1,26 +1,23 @@ from __future__ import annotations - from typing import ( cast, - Any, - Optional, Union, get_args, get_origin, TypeVar, + Type, Generic, TYPE_CHECKING, ) -from typing import Type -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod from pydantic.fields import FieldInfo from .rest_types import _T_SupportedRESTFields +from .rest_resource import RestResourceBase - -if TYPE_CHECKING: - from .rest_resource import RestResourceBase +if TYPE_CHECKING is True: + from typing import Any, Optional TV_RestResourceWalkerFutureResult = TypeVar("TV_RestResourceWalkerFutureResult") @@ -42,7 +39,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): @classmethod @abstractmethod - def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: """implementation interface to Factory. The factory will call this specialized method on each implementation to find a supported one. """ @@ -53,7 +50,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): self, subs: list[type[RestResourceWalker_Sub]], resource_name: str, - resource: FieldInfo | Type["RestResourceBase"], + resource: FieldInfo | Type[RestResourceBase], parent: Optional[RestResourceWalker_Sub] = None, argument: Optional[any] = None, ) -> Optional[RestResourceWalker_Sub]: @@ -68,15 +65,15 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): def __init__( self, resource_name: str, - resource: FieldInfo | Type["RestResourceBase"], + resource: FieldInfo | Type[RestResourceBase], parent: Optional[RestResourceWalker_Sub] = None, - annotation: Optional[type["RestResourceBase"]] = None, + annotation: Optional[type[RestResourceBase]] = None, _optional: Optional[bool] = None, argument: Optional[any] = None, ): self.argument: any = argument self.resource_name: str = resource_name - self.resource: FieldInfo | Type["RestResourceBase"] = resource + self.resource: FieldInfo | Type[RestResourceBase] = resource self.parent: Optional[RestResourceWalker_Sub] = parent self.future_results_subs: Optional[list[RestResourceWalkerFutureResult[TV_RestResourceWalkerFutureResult]]] = None @@ -85,7 +82,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): self.future_results_subs = [] self.future_result = self.cls_RestResourceWalkerFutureResult(self) - self.annotation: type["RestResourceBase"] + self.annotation: type[RestResourceBase] self.optional: bool if annotation is None or _optional is None: self.annotation, self.optional = self.ProcessAnnotation(resource) @@ -151,9 +148,9 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): @staticmethod def ProcessAnnotation( - resource: FieldInfo | Type["RestResourceBase"], + resource: FieldInfo | Type[RestResourceBase], ) -> tuple[type[Any], bool]: - from .rest_resource import RestResourceBase + # from .rest_resource import RestResourceBase _anno: Type[Any] @@ -186,7 +183,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): class RestResourceWalker_Sub_T_Dict(RestResourceWalker_Sub): @classmethod - def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: _anno, _optional = cls.ProcessAnnotation(resource) _type_resource = get_origin(_anno) return (_type_resource is dict), _anno, _optional @@ -202,7 +199,7 @@ class RestResourceWalker_Sub_T_Dict(RestResourceWalker_Sub): class RestResourceWalker_Sub_RestFields(RestResourceWalker_Sub): @classmethod - def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: _anno, _optional = cls.ProcessAnnotation(resource) return (_anno in _T_SupportedRESTFields), _anno, _optional @@ -212,9 +209,7 @@ class RestResourceWalker_Sub_RestFields(RestResourceWalker_Sub): class RestResourceWalker_Sub_RestResourceBase(RestResourceWalker_Sub): @classmethod - def check_type(cls, resource: FieldInfo | Type["RestResourceBase"]) -> tuple[bool, Type[Any], bool]: - from .rest_resource import RestResourceBase - + def check_type(cls, resource: FieldInfo | Type[RestResourceBase]) -> tuple[bool, Type[Any], bool]: _anno, _optional = cls.ProcessAnnotation(resource) return ( ((get_origin(_anno) is None) and issubclass(_anno, RestResourceBase)), @@ -236,11 +231,9 @@ class RestResourceWalker_Root: RestResourceWalker_Sub_RestResourceBase, ] - def __init__(self, resource: "RestResourceBase" | Type["RestResourceBase"]) -> None: + def __init__(self, resource: RestResourceBase | Type[RestResourceBase]) -> None: self.subwalker_argument: any = None - from .rest_resource import RestResourceBase - - self.resource: Type["RestResourceBase"] + self.resource: Type[RestResourceBase] if isinstance(resource, RestResourceBase): self.resource = type(resource) else: @@ -256,7 +249,7 @@ class RestResourceWalker_Root: if sub_walker_initial is not None: sub_walker_initial.process() sub_walker_initial.get_future() - resource_list: list[tuple[str, FieldInfo | Type["RestResourceBase"], RestResourceWalker_Sub]] = [ + resource_list: list[tuple[str, FieldInfo | Type[RestResourceBase], RestResourceWalker_Sub]] = [ (subresource_name, subresource, sub_walker_initial) for subresource_name, subresource in sub_walker_initial.get_sub_resources() ] diff --git a/src/pyrestresource/rest_types.py b/src/pyrestresource/rest_types.py index 98edc7c..6e0a7f8 100644 --- a/src/pyrestresource/rest_types.py +++ b/src/pyrestresource/rest_types.py @@ -1,14 +1,20 @@ # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring from __future__ import annotations -from enum import Enum, auto from typing import Union, get_origin, NewType, TypeVar, TYPE_CHECKING + +from enum import Enum, auto from datetime import datetime from pathlib import Path from uuid import UUID from ipaddress import IPv4Address, IPv4Network -if TYPE_CHECKING: - from .rest_resource import RestResourceBase +if TYPE_CHECKING is True: + pass + + +class RestResourceException(Exception): + pass + T_Gen_DictKeys: type = type({}.keys()) NoneType = type(None) diff --git a/test/test_rest_login.py b/test/test_rest_login.py index 96e451a..f3fb471 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -12,6 +12,7 @@ import socket import requests from contextlib import closing from multiprocessing import Process +from requests.adapters import HTTPAdapter print(__name__) print(__package__) @@ -43,20 +44,30 @@ chdir(testdir_path.parent.resolve()) # to allow mock-ing, all the tested classes are in a function def init_classes(): - user_CHACHA = UserLogin(username="chacha", secret="123456") + user_test = UserLogin(username="TestUser", secret="123456") + + class TestResource(RestResourceBase): + test_field: Optional[str] = Field("ORIGIN_VALUE") class TestResourceACL(RestResourceBase): test_field: Optional[str] = Field( "ORIGIN_VALUE", ACL=[ - ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user(name="chacha"), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test), rule=ACL_rule.ALLOW), ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ], ) @register_rest_rootpoint class RootApp(RestResourceBaseLogin): - _ar_user_login: ClassVar[list[UserLogin]] = [user_CHACHA] + _ar_user_login: ClassVar[list[UserLogin]] = [user_test] + test_resourceACL: TestResource = Field( + TestResource(), + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user(name=user_test.username), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) test_resource: TestResourceACL = TestResourceACL() # this add the classes to globals to allow using them later on @@ -73,21 +84,16 @@ def find_free_port(): def launch_server(ip, port): - print(f"port2={port}") init_classes() uvicorn.run(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True) -class Test_RestAPI_LOGIN(unittest.TestCase): +class Test_RestAPI_LOGIN_Web(unittest.TestCase): def setUp(self) -> None: chdir(testdir_path.parent.resolve()) - init_classes() - self.testapp = RootApp() - def test_access(self): + def test_login(self): ip, port = find_free_port() - print(f"ip1={ip}") - print(f"port1={port}") proc = Process( target=launch_server, args=( @@ -98,6 +104,139 @@ class Test_RestAPI_LOGIN(unittest.TestCase): proc.start() sleep(1) s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "TestUser"}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TestUser") + + finally: + proc.terminate() + s.close() + + def test_access_resourceACL(self): + ip, port = find_free_port() + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # before modification read + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resourceACL/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 500) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resourceACL", json={"test_field": "TEST SET VALUE"}) + self.assertEqual(response.status_code, 500) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resourceACL/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + # authenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resourceACL", json={"test_field": "TEST SET VALUE 2"}) + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resourceACL/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") + + finally: + proc.terminate() + s.close() + + def test_access_fieldACL(self): + ip, port = find_free_port() + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + try: # before modification read response = s.get( @@ -106,8 +245,19 @@ class Test_RestAPI_LOGIN(unittest.TestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), "ORIGIN_VALUE") - # try unauthenticated write - response = s.put(f"http://{ip}:{port}/test_resource/test_field", json='"TEST SET VALUE"') + # try unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 500) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resource", json={"test_field": "TEST SET VALUE"}) self.assertEqual(response.status_code, 500) # check not modified @@ -120,11 +270,11 @@ class Test_RestAPI_LOGIN(unittest.TestCase): # login response = s.put( f"http://{ip}:{port}/login", - json={"username": "chacha", "secret": "123456"}, + json={"username": "TestUser", "secret": "123456"}, ) self.assertEqual(response.status_code, 201) - # authenticated write + # authenticated write (to field) response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") self.assertEqual(response.status_code, 201) @@ -135,73 +285,16 @@ class Test_RestAPI_LOGIN(unittest.TestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), "TEST SET VALUE") - finally: - proc.terminate() - s.close() - - def test_login(self): - result = self.testapp.process_request("/login", rsrc_verb.GET) - print("*****************") - print(result.get_result()) - - result = self.testapp.process_request("/login/username", rsrc_verb.GET) - print("*****************") - print(result.get_result()) - - # result = self.testapp.process_request("/login/secret", rsrc_verb.GET) - # print("*****************") - # print(result.get_result()) - - result = self.testapp.process_request("/login", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') - print("*****************") - print(result.get_result()) - - result = self.testapp.process_request("/login", rsrc_verb.GET) - print("*****************") - print(result.get_result()) - - result = self.testapp.process_request("/login/username", rsrc_verb.GET) - print("*****************") - print(result.get_result()) - - # result = self.testapp.process_request("/login/secret", rsrc_verb.GET) - # print("*****************") - # print(result.get_result()) - - -class Test_RestAPI_LOGIN_Web(unittest.TestCase): - def setUp(self) -> None: - chdir(testdir_path.parent.resolve()) - - def test_login(self): - ip, port = find_free_port() - print(f"ip1={ip}") - print(f"port1={port}") - proc = Process( - target=launch_server, - args=( - ip, - port, - ), - ) - proc.start() - sleep(1) - s = requests.Session() - try: - # Login in - - response = s.put( - f"http://{ip}:{port}/login", - json={"username": "chacha", "secret": "123456"}, - ) - print(response) - print("??????") - print(response.headers) + # authenticated write (to resource) + response = s.put(f"http://{ip}:{port}/test_resource", json={"test_field": "TEST SET VALUE 2"}) self.assertEqual(response.status_code, 201) - response = s.get(f"http://{ip}:{port}/login") - - response = s.get(f"http://{ip}:{port}/") + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") finally: proc.terminate() diff --git a/test/test_rest_resource_plugins.py b/test/test_rest_resource_plugins.py index cffa355..e797b13 100644 --- a/test/test_rest_resource_plugins.py +++ b/test/test_rest_resource_plugins.py @@ -16,6 +16,7 @@ from src.pyrestresource import ( T_SupportedRESTFields, ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, + RestResourcePluginException_InvalidPluginSignature, ) testdir_path = Path(__file__).parent.resolve() @@ -34,6 +35,7 @@ def init_classes(): class ResourcePlugin_Info(ResourcePlugin_RestResourceBase_default): def handle_resource_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get: + print("HOOK GET !!") return Info_get(version="65.45", api_version="98.321") class Info_get(RestResourceBase): @@ -69,41 +71,9 @@ def init_classes(): def init_bad_plugin1(): - # plugin with missing handle_resource_put() method + # plugin not inheriting from the right base type class ResourcePlugin_TestResource: - def handle_field_get(self, resource: TestResource, params: RestRequestParams_GET) -> TestResource: - return resource - - class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] - - @register_rest_rootpoint - class RootApp2(RestResourceBase): - test: TestResource = Field(default=TestResource(tetvaluestr="testvalue")) - - RootApp2() - - -def init_bad_plugin2(): - # plugin with missing handle_resource_get() method - class ResourcePlugin_TestResource: - def handle_field_put(self, resource: TestResource, params: RestRequestParams_PUT) -> TestResource: - return resource - - class TestResource(RestResourceBase): - tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] - - @register_rest_rootpoint - class RootApp2(RestResourceBase): - test: TestResource = Field(default=TestResource(tetvaluestr="testvalue")) - - RootApp2() - - -def init_bad_plugin3(): - # wrong plugin - class ResourcePlugin_TestResource(ResourcePlugin_RestResourceBase_default): - pass + ... class TestResource(RestResourceBase): tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)] @@ -200,9 +170,5 @@ class Test_RestAPI_Plugin_GET(unittest.TestCase): self.assertEqual(result.get_result(), '"1.5.6"') def test_defect_plugin_field(self): - with self.assertRaises(RuntimeError): + with self.assertRaises(RestResourcePluginException_InvalidPluginSignature): init_bad_plugin1() - with self.assertRaises(RuntimeError): - init_bad_plugin2() - with self.assertRaises(RuntimeError): - init_bad_plugin3() diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py index 6eebc21..9329214 100644 --- a/test/test_rest_webserver.py +++ b/test/test_rest_webserver.py @@ -13,7 +13,7 @@ import socket import requests from contextlib import closing from multiprocessing import Process - +from requests.adapters import HTTPAdapter print(__name__) print(__package__) @@ -121,7 +121,6 @@ def find_free_port(): def launch_server(ip, port): - print(f"port2={port}") init_classes() uvicorn.run(f"{__loader__.name}:RootApp", port=port, host="0.0.0.0", log_level="warning", factory=True) @@ -132,8 +131,6 @@ class Test_RestAPI_WebServer(unittest.TestCase): def test_nomal_AllCmd_games(self): ip, port = find_free_port() - print(f"ip1={ip}") - print(f"port1={port}") proc = Process( target=launch_server, args=( @@ -144,6 +141,8 @@ class Test_RestAPI_WebServer(unittest.TestCase): proc.start() sleep(1) s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + try: # Fetching games response = s.get(f"http://{ip}:{port}/games") @@ -291,8 +290,6 @@ class Test_RestAPI_WebServer(unittest.TestCase): n_loop = 10000 ip, port = find_free_port() - print(f"ip1={ip}") - print(f"port1={port}") proc = Process( target=launch_server, args=( @@ -303,6 +300,8 @@ class Test_RestAPI_WebServer(unittest.TestCase): proc.start() sleep(1) s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + try: start = time() for _ in range(n_loop):