From f00cf7b4b2d60fc3eb55240eef931fca77b46277 Mon Sep 17 00:00:00 2001 From: cclecle Date: Sun, 5 Nov 2023 15:38:08 +0000 Subject: [PATCH] continuing implementation of login and session --- .settings/org.eclipse.core.resources.prefs | 1 + src/pyrestresource/__init__.py | 1 + src/pyrestresource/helpers.py | 14 +++ src/pyrestresource/rest_ACL.py | 4 - src/pyrestresource/rest_login.py | 127 ++++++++++++++++++++ src/pyrestresource/rest_request.py | 32 ++++- src/pyrestresource/rest_resource.py | 54 ++++++--- src/pyrestresource/rest_resource_handler.py | 34 +++--- src/pyrestresource/rest_resource_plugin.py | 37 ++++-- test/test_ACL.py | 16 +-- test/test_rest_login.py | 118 ++++++++++-------- 11 files changed, 325 insertions(+), 113 deletions(-) create mode 100644 src/pyrestresource/rest_login.py diff --git a/.settings/org.eclipse.core.resources.prefs b/.settings/org.eclipse.core.resources.prefs index e658763..c89f8d0 100644 --- a/.settings/org.eclipse.core.resources.prefs +++ b/.settings/org.eclipse.core.resources.prefs @@ -1,6 +1,7 @@ eclipse.preferences.version=1 encoding//src/pyrestresource/__init__.py=utf-8 encoding//src/pyrestresource/__metadata__.py=utf-8 +encoding//src/pyrestresource/rest_login.py=utf-8 encoding//src/pyrestresource/rest_resource.py=utf-8 encoding//src/pyrestresource/rest_resource_handler_walker.py=utf-8 encoding/=UTF-8 diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index 15d39df..3722afd 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -54,3 +54,4 @@ from .rest_resource_plugin import ( ResourcePlugin_dict_default, ) from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule +from .rest_login import RestResourceBaseLogin, UserLogin diff --git a/src/pyrestresource/helpers.py b/src/pyrestresource/helpers.py index a39262c..a08067b 100644 --- a/src/pyrestresource/helpers.py +++ b/src/pyrestresource/helpers.py @@ -15,3 +15,17 @@ class _JSONEncoder(json.JSONEncoder): # if the obj is uuid, we simply return the value of uuid return str(o) return json.JSONEncoder.default(self, o) + + +def parse_dict_cookies(cookies: str) -> dict[str, str]: + result = {} + for item in cookies.split(";"): + item = item.strip() + if not item: + continue + if "=" not in item: + result[item] = None + continue + name, value = item.split("=", 1) + result[name] = value + return result diff --git a/src/pyrestresource/rest_ACL.py b/src/pyrestresource/rest_ACL.py index a75fa8b..3ebdbaa 100644 --- a/src/pyrestresource/rest_ACL.py +++ b/src/pyrestresource/rest_ACL.py @@ -22,10 +22,6 @@ class ACL_target_group(ACL_target): name: str -class ACL_target_group_Annonymous(ACL_target): - name: str = "__ANNONYMOUS__" - - class ACL_target_group_Any(ACL_target_group): name: str = "__ANY__" diff --git a/src/pyrestresource/rest_login.py b/src/pyrestresource/rest_login.py new file mode 100644 index 0000000..0fa0c2e --- /dev/null +++ b/src/pyrestresource/rest_login.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# pyrestresource(c) by chacha +# +# pyrestresource is licensed under a +# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License. +# +# You should have received a copy of the license along with this +# work. If not, see . + +# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring + +"""CLI interface module""" +from __future__ import annotations + +from typing import Optional, ClassVar, TYPE_CHECKING +from secrets import token_hex, compare_digest +from datetime import datetime + +from pydantic import BaseModel, Field + +from .rest_types import rsrc_verb +from .rest_resource import RestResourceBase + +from .rest_request import RestRequest, RestRequestParams_GET +from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule + +if TYPE_CHECKING or True: + from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default + + +class UserLogin(BaseModel): + username: str + secret: str + + +class UserSession(BaseModel): + last_update: datetime + user_login: UserLogin + host: Optional[str] + + +class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): + ar_UserLogin: list[UserLogin] = [] + + def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login: + print("hook GET") + print(resource) + print(params) + return resource + + def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login: + print("hook PUT") + # print(self.get_ar_userlogin()) + print(resource.username) + print(resource.secret) + + token = self.user_login(resource.username, resource.secret) + self.set_resp_cookie_value("Authorization", f"Bearer {token}") + + return resource + + +class Login(RestResourceBase): + username: Optional[str] = Field(None) + secret: Optional[str] = Field( + None, + exclude=True, + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + + +class RestResourceBaseLogin(RestResourceBase): + _ar_user_login: ClassVar[list[UserLogin]] = [] + _ar_user_session: dict[str, UserSession] = {} + _max_session_time_minutes: ClassVar[int] = 20 + login: Login = Field(default=Login(), plugin=ResourcePlugin_Login) + + def _process_request_session(self, request: RestRequest) -> None: + auth_cookie = request.get_cookie("Authorization") + if auth_cookie != None: + if auth_cookie in self._ar_user_session: + print("USER SESSION FOUND !") + print(self._ar_user_session[auth_cookie].user_login.username) + print(auth_cookie) + + time_diff_min = (datetime.now() - self._ar_user_session[auth_cookie].last_update).total_seconds() / 60 + + if time_diff_min > self._max_session_time_minutes: + del self._ar_user_session[auth_cookie] + raise RuntimeError("session timeout ! (session reseted)") + + request.set_user(self._ar_user_session[auth_cookie].user_login.username) + return + + print("Invalid session") + return + + print("non-connected user") + + def user_login(self, user_name: str, user_secret: str, request: RestRequest) -> str: + already_failed: bool = False + + for iter_user_login in self._ar_user_login: + username_ok: bool = compare_digest(user_name, iter_user_login.username) + secret_ok: bool = compare_digest(user_secret, iter_user_login.secret) + + if username_ok is True: + if secret_ok is True and not already_failed: + return self._register_user_session(iter_user_login, request) + else: + already_failed = True + else: + pass + pass + + if already_failed: + raise RuntimeError("Wrong auth") # TODO: specific exception + + def _register_user_session(self, user_login: UserLogin, request: RestRequest) -> str: + token = token_hex(16) + new_user_session = UserSession(last_update=datetime.now(), user_login=user_login, host=request.get_host()) + self._ar_user_session[f"Bearer {token}"] = new_user_session + return token diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index 4549d0e..c5d08cd 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -3,11 +3,14 @@ from __future__ import annotations from typing import ( + Any, Optional, Generic, ) from re import sub from urllib.parse import urlparse, parse_qs +from http.cookies import SimpleCookie + from pydantic import BaseModel, Field from typeguard import check_type @@ -26,7 +29,8 @@ from .rest_request_opt import ( _T_RestRequestParams_PUT, ) -from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group, ACL_target_group_Annonymous +from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group +from .helpers import parse_dict_cookies class RequestFactory( @@ -114,6 +118,8 @@ class RestRequest(Generic[_T_RestRequestParams]): self.url: str self.verb: rsrc_verb self.data: dict + self.raw_headers: list[Any] + self.headers: dict[str, None | str | dict[str, None | str]] = {"host": None, "cookie": {}} self._saved_url_params: dict self.ReqParams: _T_RestRequestParams = type_request_params() self.url_stack: list[str] @@ -122,7 +128,7 @@ class RestRequest(Generic[_T_RestRequestParams]): self.incoming_cookie: dict[str, str] = incoming_cookie self.outgoing_cookie: dict[str, str] = outgoing_cookie self.user: ACL_target_user = ACL_target_user_Annonymous() - self.group: ACL_target_group = ACL_target_group_Annonymous() + self.groups: list[ACL_target_group] = [] self.result: Optional[str] = None # = or create a fresh one = @@ -151,6 +157,24 @@ class RestRequest(Generic[_T_RestRequestParams]): self._saved_url_stack = self.url_stack.copy() self.url_stack_index = 0 + def set_headers(self, headers: list[Any]) -> None: + self.raw_headers = headers + for elem in self.raw_headers: + if elem[0] == b"host": + self.headers["host"] = elem[1].decode("utf-8") + # elif elem[0] == b"user-agent": + # self.headers["user-agent"] = elem[1].decode("utf-8") + elif elem[0] == b"cookie": + self.headers["cookie"] = parse_dict_cookies(elem[1].decode("utf-8")) + + def get_cookie(self, key: str) -> str | None: + if key not in self.headers["cookie"]: + return None + return self.headers["cookie"][key] + + def get_host(self) -> str: + print(self.headers["host"]) + def set_result(self, result: str): self.result = result @@ -160,8 +184,8 @@ class RestRequest(Generic[_T_RestRequestParams]): def set_user(self, user: ACL_target_user): self.user: ACL_target_user = user - def set_group(self, group: ACL_target_group): - self.group: ACL_target_group = group + def add_group(self, group: ACL_target_group): + self.groups.append(group) def update_ReqParams(self, type_request_params: type[_T_RestRequestParams]): self.ReqParams = type_request_params(**self._saved_url_params) diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index cfad675..75525f5 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -15,6 +15,7 @@ from __future__ import annotations from abc import ABC from typing import ( + Any, cast, ClassVar, get_args, @@ -39,7 +40,6 @@ from .rest_ACL import ( ACL_target_user, ACL_target_group, ACL_target_user_Annonymous, - ACL_target_group_Annonymous, ACL_target_group_Any, ACL_rule, ) @@ -219,8 +219,8 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ ): if "plugin" in self.resource.json_schema_extra: plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] - if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase): - raise RuntimeError("Wrong plugin signature provided") + if not issubclass(plugin_resource, ResourcePlugin_RestResourceBase): + raise RuntimeError(f"Wrong plugin signature provided for {plugin_resource} : {type(plugin_resource)}") self.parent.annotation._plugins_[self.resource_name] = plugin_resource # print("ADD RESOURCE PLUGIN") @@ -246,7 +246,7 @@ def register_rest_rootpoint(klass: type[RestResourceBase]): class RestResourceBase(ABC, BaseModel, validate_assignment=True): - _resp_cookies: ClassVar[dict[str, str]] = dict() + # _resp_cookies: ClassVar[dict[str, str]] = {} _dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {} _dict_value_type_: ClassVar[dict[str, T_T_DictValues]] = {} _model_dump_excluded_: ClassVar[dict[str, bool]] = {} @@ -264,43 +264,45 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): ] ] = {} - def _check_acl(self, user: ACL_target_user, group: ACL_target_group, verb: rsrc_verb, field: str, is_self: bool = True): - # print(f"evaluate self ACLs rule: {self._ACL_record_}") + def _check_acl(self, user: str, groups: list[ACL_target_group], verb: rsrc_verb, field: str, is_self: bool = True): + print(f"evaluate self ACLs rule: {self._ACL_record_}") + print(f"user: {user}") + print(f"groups: {groups}") if is_self and verb is rsrc_verb.GET and self.model_fields[field].exclude is True: # print("ALLOWED (excluded field)") return for acl in self._ACL_record_[field]: - # print(f"evaluate ACL rule: {acl}") + print(f"evaluate ACL rule: {acl}") if verb in acl.verbs: if isinstance(acl.target, ACL_target_user): - if user == acl.target: + if user == acl.target.name: if acl.rule is ACL_rule.ALLOW: - # print("ALLOWED (user)") + print("ALLOWED (user)") return raise RuntimeError(f"Not allowed access detected: {field}") elif isinstance(acl.target, ACL_target_group): - if group == acl.target or acl.target == ACL_target_group_Any(): + if acl.target.name in groups or isinstance(acl.target, ACL_target_group_Any): if acl.rule is ACL_rule.ALLOW: - # print("ALLOWED (group)") + print("ALLOWED (group)") return raise RuntimeError(f"Not allowed access detected: {field}") else: raise RuntimeError(f"Wrong ACL target type: {field}") - # print("ALLOWED (Default)") + print("ALLOWED (Default)") def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None: """Check ACL on requested field access""" - self._check_acl(request.user, request.group, request.get_verb(), request.get_resource_origin(req_index), False) + self._check_acl(request.user, request.groups, request.get_verb(), request.get_resource_origin(req_index), False) def check_acl_self(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None: """Check ACL on requested field operation (involving checking sub-fields)""" if request.get_verb() is rsrc_verb.GET: for key in self.model_fields.keys(): - self._check_acl(request.user, request.group, rsrc_verb.GET, key) + self._check_acl(request.user, request.groups, rsrc_verb.GET, key) elif request.get_verb() is rsrc_verb.PUT: for key in new_data.keys(): if key in self.model_fields: - self._check_acl(request.user, request.group, rsrc_verb.PUT, key) + self._check_acl(request.user, request.groups, rsrc_verb.PUT, key) else: raise RuntimeError("Incompatible verb") @@ -324,21 +326,24 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): async def __call__(self, scope, receive, send): assert scope["type"] == "http" + method = scope["method"] + assert method in ["GET", "DELETE", "PUT", "POST"] + if b"content-type" in scope["headers"]: assert scope["headers"][b"content-type"] == b"application/json" - # import pprint + import pprint - # print("----REC HEADER ---") - # pprint.pprint(scope["headers"]) + print("----REC HEADER ---") + pprint.pprint(scope["headers"]) body = await self.read_body(receive) verb = rsrc_verb[scope["method"]] request: RestRequest = self.process_request( - scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8") + scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8"), scope["headers"] ) assert request != None @@ -373,12 +378,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): } ) + def _process_request_session(self, request: RestRequest) -> None: + pass + def process_request( self, url: str, verb: rsrc_verb = rsrc_verb.GET, data_json: Optional[str] = None, query_string: Optional[str] = None, + headers: Optional[list[Any]] = None, ) -> RestRequest: from .rest_resource_handler import ( ResourceHandler, @@ -389,11 +398,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if data_json: data = json.loads(data_json) + # creating the root handler ressource_handler: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) + # preparing request & session request: RestRequest = ressource_handler.get_request() - assert request != None + if headers is not None: + request.set_headers(headers) + self._process_request_session(request) + # processing the verb result = ressource_handler.process_verb() # print("OOO") diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 01ecc0b..5e99bd7 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -23,7 +23,6 @@ from .rest_ACL import ( ACL_target_user, ACL_target_group, ACL_target_user_Annonymous, - ACL_target_group_Annonymous, ACL_target_group_Any, ACL_rule, ) @@ -101,6 +100,7 @@ class ResourceHandler( self.next_handler: Optional[ResourceHandler] = None self.saved_url: list[str] = [] self.resource: _T_Resource = resource + self.root_resource: _T_Resource = resource if prev_handler is None else prev_handler.root_resource self.req: RestRequest if prev_handler is not None: self.prev_handler = prev_handler @@ -484,14 +484,6 @@ class ResourceHandler_RestResourceBase( if len(self.req.get_url_stack()) == 0: # destination reached if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET: raise RuntimeError(f"Not allowed READ access detected: {self.req.get_url_stack()}") - """ # not sure init_var has the expected behavior (read_only) - if self.resource.model_fields[self.req.get_resource_origin(0)].init_var is True and self.req.get_verb() in [ - rsrc_verb.POST, - rsrc_verb.PUT, - rsrc_verb.DELETE, - ]: - raise RuntimeError(f"Not allowed WRITE access detected: {self.req.get_url_stack()}") - """ def _handle_process_get(self, params) -> RestResourceBase: # print(f"{type(self).__name__}->_process_get()") @@ -504,11 +496,15 @@ class ResourceHandler_RestResourceBase( for key, attr in self.resource.model_fields.items(): if key in self.resource._plugins_: if isinstance(self.resource._plugins_[key], ResourcePlugin_field): - plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, self.resource._plugins_[key](self.req)) + plugin_field: ResourcePlugin_field = cast( + ResourcePlugin_field, self.resource._plugins_[key](self.req, self.root_resource) + ) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_field_get(value, params)) elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): - plugin_field: ResourcePlugin_field = cast(ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req)) + plugin_field: ResourcePlugin_field = cast( + ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource) + ) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_resource_get(value, params)) @@ -530,14 +526,14 @@ class ResourceHandler_RestResourceBase( if isinstance(self.resource._plugins_[key], ResourcePlugin_field): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.resource._plugins_[key](self.req), + self.resource._plugins_[key](self.req, self.root_resource), ) value = plugin_rsrc.handle_field_get(value, params) elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.resource._plugins_[key](self.req), + self.resource._plugins_[key](self.req, self.root_resource), ) value = plugin_rsrc.handle_resource_get(value, params) @@ -559,7 +555,9 @@ class ResourceHandler_RestResourceBase( for key, attr in _new_resrc.model_fields.items(): if key in _new_resrc._plugins_: if isinstance(_new_resrc._plugins_[key], ResourcePlugin_field): - plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, _new_resrc._plugins_[key](self.req)) + plugin_field: ResourcePlugin_field = cast( + ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.root_resource) + ) value = getattr(_new_resrc, key) setattr(_new_resrc, key, plugin_field.handle_field_put(value, params)) @@ -574,7 +572,7 @@ class ResourceHandler_RestResourceBase( if key in self.prev_handler.prev_handler.resource._plugins_: plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.prev_handler.prev_handler.resource._plugins_[key](self.req), + self.prev_handler.prev_handler.resource._plugins_[key](self.req, self.root_resource), ) _new_resrc = plugin_rsrc.handle_dict_elem_put(_new_resrc, params) # element is within a RestResourceBase @@ -583,7 +581,7 @@ class ResourceHandler_RestResourceBase( if key in self.prev_handler.resource._plugins_: plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.prev_handler.resource._plugins_[key](self.req), + self.prev_handler.resource._plugins_[key](self.req, self.root_resource), ) _new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params) @@ -634,7 +632,7 @@ class ResourceHandler_simple( if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, - self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req), + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource), ) return plugin_simple.handle_field_get(self.resource, params) @@ -655,7 +653,7 @@ class ResourceHandler_simple( # print("PLUGIN FOUND") plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, - self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req), + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource), ) # print(value) value = plugin_simple.handle_field_put(value, params) diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py index 9308aa6..72290f7 100644 --- a/src/pyrestresource/rest_resource_plugin.py +++ b/src/pyrestresource/rest_resource_plugin.py @@ -1,7 +1,7 @@ from __future__ import annotations -from typing import Optional, Protocol, runtime_checkable, TYPE_CHECKING -from abc import abstractmethod +from typing import Optional, Generic, TYPE_CHECKING +from abc import abstractmethod, ABC from .rest_types import ( _T_DictValues, @@ -12,6 +12,7 @@ from .rest_types import ( from .rest_request import RestRequest + if TYPE_CHECKING or True: from .rest_request_opt import ( RestRequestParams_GET, @@ -26,21 +27,33 @@ if TYPE_CHECKING or True: ) -class ResourcePlugin(Protocol): - def __init__(self, request: RestRequest) -> None: - self.request: RestRequest = request +class ResourcePlugin(ABC): + def __init__(self, request: RestRequest, root_resource: "RestResourceBase") -> None: + self.__request: RestRequest = request + self.__root_resource: RestRequest = root_resource - def set_resp_cookie(self, name: str, value: str): + def user_login(self, user_name: str, user_secret: str) -> str: + return self.__root_resource.user_login(user_name, user_secret, self.__request) + + """ + def get_ar_userlogin(self): + print("===========") + return self.__root_resource.get_ar_user_login() + """ + + def getr_req_cookie_value(self, key: str) -> Optional[str]: + return self.__request.incoming_cookie[key] + + def set_resp_cookie_value(self, key: str, value: str): # print("AAA") # print(name) # print(value) # print(self.cookies) # print(type(self.cookies)) - self.request.outgoing_cookie[name] = value + self.__request.outgoing_cookie[key] = value -@runtime_checkable -class ResourcePlugin_field(ResourcePlugin, Protocol[TV_SupportedRESTFields]): +class ResourcePlugin_field(ResourcePlugin, Generic[TV_SupportedRESTFields]): @abstractmethod def handle_field_get(self, resource: TV_SupportedRESTFields, params: RestRequestParams_GET) -> TV_SupportedRESTFields: ... @@ -60,8 +73,7 @@ class ResourcePlugin_field_default(ResourcePlugin_field[TV_SupportedRESTFields]) return resource -@runtime_checkable -class ResourcePlugin_RestResourceBase(ResourcePlugin, Protocol[TV_RestResourceBase]): +class ResourcePlugin_RestResourceBase(ResourcePlugin, Generic[TV_RestResourceBase]): @abstractmethod def handle_resource_get( self, @@ -97,8 +109,7 @@ class ResourcePlugin_RestResourceBase_default(ResourcePlugin_RestResourceBase[TV return resource -@runtime_checkable -class ResourcePlugin_dict(ResourcePlugin, Protocol[_T_DictKey, _T_DictValues]): +class ResourcePlugin_dict(ResourcePlugin, Generic[_T_DictKey, _T_DictValues]): @abstractmethod def handle_dict_get_keys( self, diff --git a/test/test_ACL.py b/test/test_ACL.py index 17859e3..5e69df8 100644 --- a/test/test_ACL.py +++ b/test/test_ACL.py @@ -59,7 +59,7 @@ def init_classes(): resource_with_secret_ACL: TestResource = Field( default=TestResource(), ACL=[ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY)] ) - resource2: TestResource2 = Field(TestResource2()) + resource_ro: TestResource2 = Field(TestResource2()) # this add the classes to globals to allow using them later on # => this is only for uinit-testing purpose and is not needed in real use @@ -77,21 +77,23 @@ class Test_RestAPI_ACL(unittest.TestCase): result = self.testapp.process_request("/", rsrc_verb.GET) self.assertEqual(result.get_result(), "{}") - result = self.testapp.process_request("/resource2", rsrc_verb.GET) + result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "3.2.1"}') - self.testapp.process_request("/resource2/version", rsrc_verb.PUT, '"6.6.6"') + self.testapp.process_request("/resource_ro/version", rsrc_verb.PUT, '"6.6.6"') - result = self.testapp.process_request("/resource2", rsrc_verb.GET) + result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') with self.assertRaises(RuntimeError): # TODO: custom exception - self.testapp.process_request("/resource2/version_ro", rsrc_verb.PUT, '"6.6.6"') + self.testapp.process_request("/resource_ro/version_ro", rsrc_verb.PUT, '"6.6.6"') + self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3") with self.assertRaises(RuntimeError): # TODO: custom exception - self.testapp.process_request("/resource2", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}') + self.testapp.process_request("/resource_ro", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}') + self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3") - result = self.testapp.process_request("/resource2", rsrc_verb.GET) + result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') def test_subresource(self): diff --git a/test/test_rest_login.py b/test/test_rest_login.py index ff71667..96e451a 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -3,7 +3,7 @@ import unittest from unittest.mock import patch from os import chdir from pathlib import Path -from typing import Optional, Annotated +from typing import Optional, Annotated, ClassVar from pydantic import Field from uuid import UUID, uuid4 from time import time, sleep @@ -12,16 +12,17 @@ import socket import requests from contextlib import closing from multiprocessing import Process -from secrets import token_hex print(__name__) print(__package__) -from pydantic import BaseModel from src.pyrestresource import ( - register_rest_rootpoint, + ACL_target_user, + UserLogin, RestResourceBase, + RestResourceBaseLogin, + register_rest_rootpoint, rsrc_verb, RestRequestParams_GET, RestRequestParams_POST, @@ -42,58 +43,25 @@ chdir(testdir_path.parent.resolve()) # to allow mock-ing, all the tested classes are in a function def init_classes(): - class UserLogin(BaseModel): - username: str - secret: str - token: Optional[str] = None + user_CHACHA = UserLogin(username="chacha", secret="123456") - class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): - ar_UserLogin: list[UserLogin] = [UserLogin(username="chacha", secret="123456")] - - def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login: - print("hook GET") - print(resource) - print(params) - return resource - - def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login: - print("hook PUT") - - print(resource.username) - print(resource.secret) - - for _UserLogin in self.ar_UserLogin: - if _UserLogin.username == resource.username and _UserLogin.secret == resource.secret: - print("user connected") - _UserLogin.token = token_hex(16) - self.set_resp_cookie("test", _UserLogin.token) - print(f"generated token: {_UserLogin.token}") - return resource - print("login NOT found") - # print(resource) - # print(resource.username) - # print(resource.secret) - # print(params) - return resource - - class Login(RestResourceBase): - username: Optional[str] = Field(None) - secret: Optional[str] = Field( - None, - exclude=True, + class TestResourceACL(RestResourceBase): + test_field: Optional[str] = Field( + "ORIGIN_VALUE", ACL=[ - ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), - ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user(name="chacha"), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ], ) @register_rest_rootpoint - class RootApp(RestResourceBase): - login: Login = Field(default=Login(), plugin=ResourcePlugin_Login) + class RootApp(RestResourceBaseLogin): + _ar_user_login: ClassVar[list[UserLogin]] = [user_CHACHA] + test_resource: TestResourceACL = TestResourceACL() # this add the classes to globals to allow using them later on # => this is only for uinit-testing purpose and is not needed in real use - globals()[Login.__name__] = Login + globals()[TestResourceACL.__name__] = TestResourceACL globals()[RootApp.__name__] = RootApp @@ -116,6 +84,61 @@ class Test_RestAPI_LOGIN(unittest.TestCase): init_classes() self.testapp = RootApp() + def test_access(self): + ip, port = find_free_port() + print(f"ip1={ip}") + print(f"port1={port}") + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + try: + # before modification read + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # try unauthenticated write + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json='"TEST SET VALUE"') + self.assertEqual(response.status_code, 500) + + # check not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "chacha", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # authenticated write + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + finally: + proc.terminate() + s.close() + def test_login(self): result = self.testapp.process_request("/login", rsrc_verb.GET) print("*****************") @@ -172,6 +195,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): json={"username": "chacha", "secret": "123456"}, ) print(response) + print("??????") print(response.headers) self.assertEqual(response.status_code, 201)