diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index d129cc3..08402e3 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -32,7 +32,6 @@ if TYPE_CHECKING: T_T_DictKey, T_DictValues, T_T_DictValues, - RestResourceException, ) from .rest_request_opt import ( @@ -51,8 +50,20 @@ from .rest_resource_plugin import ( ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, ResourcePlugin_dict_default, - RestResourcePluginException, - RestResourcePluginException_InvalidPluginSignature, ) from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule -from .rest_login import RestResourceBaseLogin, UserLogin +from .rest_login import ( + RestResourceBaseLogin, + UserLogin, +) + +from .rest_exceptions import ( + RestResourceException, + RestResourceLoginException, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_ClientChange, + RestResourceLoginException_InvalidSession, + RestResourcePluginException, + RestResourcePluginException_InvalidPluginSignature, + RestResourceHandlerException_Forbiden, +) diff --git a/src/pyrestresource/helpers.py b/src/pyrestresource/helpers.py index e8c13d1..2d2723e 100644 --- a/src/pyrestresource/helpers.py +++ b/src/pyrestresource/helpers.py @@ -4,6 +4,7 @@ from __future__ import annotations from uuid import UUID import json +import traceback from .rest_types import T_Gen_DictKeys @@ -30,3 +31,10 @@ def parse_dict_cookies(cookies: str) -> dict[str, str]: name, value = item.split("=", 1) result[name] = value return result + + +def forward_exception(e: Exception, forward: bool) -> None: + if forward: + raise e from None + else: + traceback.print_exc() diff --git a/src/pyrestresource/rest_exceptions.py b/src/pyrestresource/rest_exceptions.py new file mode 100644 index 0000000..c779dac --- /dev/null +++ b/src/pyrestresource/rest_exceptions.py @@ -0,0 +1,58 @@ +class RestResourceException(Exception): + pass + + +class RestResourceModelException(RestResourceException): + pass + + +class RestResourceModelException_ACL(RestResourceModelException): + pass + + +class RestResourceHandlerException(RestResourceException): + pass + + +class RestResourceHandlerException_ResourceNotFound(RestResourceHandlerException): + pass + + +class RestResourceHandlerException_MethodNotAllowed(RestResourceHandlerException): + pass + + +class RestResourceHandlerException_BadRequest(RestResourceHandlerException): + pass + + +class RestResourceHandlerException_Forbiden(RestResourceHandlerException): + pass + + +class RestResourceLoginException(RestResourceException): + pass + + +class RestResourceLoginException_SessionTimeout(RestResourceLoginException): + pass + + +class RestResourceLoginException_ClientChange(RestResourceLoginException): + pass + + +class RestResourceLoginException_InvalidSession(RestResourceLoginException): + pass + + +class RestResourceLoginException_InvalidCredentials(RestResourceLoginException): + pass + + +class RestResourcePluginException(RestResourceException): + pass + + +class RestResourcePluginException_InvalidPluginSignature(RestResourcePluginException): + pass diff --git a/src/pyrestresource/rest_login.py b/src/pyrestresource/rest_login.py index 707e421..f91de8c 100644 --- a/src/pyrestresource/rest_login.py +++ b/src/pyrestresource/rest_login.py @@ -15,13 +15,19 @@ from __future__ import annotations from typing import Optional, ClassVar, TYPE_CHECKING from secrets import token_hex, compare_digest -from datetime import datetime +from datetime import datetime, timedelta from pydantic import BaseModel, Field from .rest_types import rsrc_verb from .rest_resource import RestResourceBase from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule, ACL_target_user from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default +from .rest_exceptions import ( + RestResourceLoginException_InvalidCredentials, + RestResourceLoginException_ClientChange, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_InvalidSession, +) if TYPE_CHECKING is True: from .rest_request import RestRequest, RestRequestParams_GET @@ -35,7 +41,7 @@ class UserLogin(BaseModel): class UserSession(BaseModel): last_update: datetime user_login: UserLogin - host: Optional[str] + client: Optional[tuple[str, int]] class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): @@ -65,30 +71,38 @@ class Login(RestResourceBase): class RestResourceBaseLogin(RestResourceBase): _ar_user_login: ClassVar[list[UserLogin]] = [] _ar_user_session: dict[str, UserSession] = {} - _max_session_time_minutes: ClassVar[int] = 20 + _max_session_inactive: ClassVar[timedelta] = timedelta(minutes=20) + _max_session_time: ClassVar[timedelta] = timedelta(hours=12) login: Login = Field(default=Login(), plugin=ResourcePlugin_Login) + def get_new_cookie_expiration_date(self) -> datetime: + return datetime.now() + self._max_session_time + def _process_request_session(self, request: RestRequest) -> None: + # print(f"[TRACE] {type(self).__name__}->_process_request_session()") + # print(f"[TRACE] request: {id(request)}") auth_cookie = request.get_cookie("Authorization") if auth_cookie != None: if auth_cookie in self._ar_user_session: - print("USER SESSION FOUND !") - print(self._ar_user_session[auth_cookie].user_login.username) - print(auth_cookie) + # print(f"SESSION FOUND for {request.get_client()}") - time_diff_min = (datetime.now() - self._ar_user_session[auth_cookie].last_update).total_seconds() / 60 - - if time_diff_min > self._max_session_time_minutes: + if self._ar_user_session[auth_cookie].client != request.get_client(): del self._ar_user_session[auth_cookie] - raise RuntimeError("session timeout ! (session reseted)") + raise RestResourceLoginException_ClientChange() + + time_diff = datetime.now() - self._ar_user_session[auth_cookie].last_update + if time_diff > self._max_session_inactive: + del self._ar_user_session[auth_cookie] + raise RestResourceLoginException_SessionTimeout() request.set_user(ACL_target_user(name=self._ar_user_session[auth_cookie].user_login.username)) + # print("SESSION RECOVERED") return - print("Invalid session") + raise RestResourceLoginException_InvalidSession() return - print("non-connected user") + # print(f"non-connected user {request.get_client()}") def user_login(self, user_name: str, user_secret: str, request: RestRequest) -> str: already_failed: bool = False @@ -107,10 +121,10 @@ class RestResourceBaseLogin(RestResourceBase): pass if already_failed: - raise RuntimeError("Wrong auth") # TODO: specific exception + raise RestResourceLoginException_InvalidCredentials() def _register_user_session(self, user_login: UserLogin, request: RestRequest) -> str: token = token_hex(16) - new_user_session = UserSession(last_update=datetime.now(), user_login=user_login, host=request.get_host()) + new_user_session = UserSession(last_update=datetime.now(), user_login=user_login, client=request.get_client()) self._ar_user_session[f"Bearer {token}"] = new_user_session return token diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index 0d37b7d..e7dd54c 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -31,6 +31,7 @@ from .helpers import parse_dict_cookies if TYPE_CHECKING is True: from typing import Optional from .rest_types import T_SupportedRESTFields + from .rest_resource import RestResourceBase class RequestFactory( @@ -49,7 +50,9 @@ class RequestFactory( cls_RestRequestParams_POST: type[RestRequestParams_POST] = Field(default=RestRequestParams_POST) cls_RestRequestParams_DELETE: type[RestRequestParams_DELETE] = Field(default=RestRequestParams_DELETE) - def get_RestRequest(self, url: str, verb: rsrc_verb, data: dict, query_string: Optional[str] = None) -> RestRequest: + def get_RestRequest( + self, root_resource: RestResourceBase, url: str, verb: rsrc_verb, data: dict, query_string: Optional[str] = None + ) -> RestRequest: """get a RestRequets instance based on LUT_verb configuration Args: @@ -60,14 +63,14 @@ class RequestFactory( # /!\ mypy seems not being able to propagate typevar to composed classes if verb is rsrc_verb.GET: - return RestRequest[RestRequestParams_GET](self.cls_RestRequestParams_GET, url, verb, data, query_string) + return RestRequest[RestRequestParams_GET](self.cls_RestRequestParams_GET, root_resource, url, verb, data, query_string) if verb is rsrc_verb.PUT: - return RestRequest[RestRequestParams_PUT](self.cls_RestRequestParams_PUT, url, verb, data, query_string) + return RestRequest[RestRequestParams_PUT](self.cls_RestRequestParams_PUT, root_resource, url, verb, data, query_string) if verb is rsrc_verb.POST: - return RestRequest[RestRequestParams_POST](self.cls_RestRequestParams_POST, url, verb, data, query_string) + return RestRequest[RestRequestParams_POST](self.cls_RestRequestParams_POST, root_resource, url, verb, data, query_string) if verb is rsrc_verb.DELETE: - return RestRequest[RestRequestParams_DELETE](self.cls_RestRequestParams_DELETE, url, verb, data, query_string) - raise RuntimeError("Invalid Verb") + return RestRequest[RestRequestParams_DELETE](self.cls_RestRequestParams_DELETE, root_resource, url, verb, data, query_string) + raise RestResourceHandlerException_MethodNotAllowed("Invalid Verb") def update_RestRequest(self, request: RestRequest) -> None: """create an updated copy of a RestRequest object based on a different LUT_verb configuration @@ -85,7 +88,7 @@ class RequestFactory( elif request.verb is rsrc_verb.DELETE: request.update_ReqParams(self.cls_RestRequestParams_DELETE) else: - raise RuntimeError("Invalid Verb") + raise RestResourceHandlerException_MethodNotAllowed("Invalid Verb") return @@ -96,12 +99,11 @@ class RestRequest(Generic[_T_RestRequestParams]): def __init__( self, type_request_params: type[_T_RestRequestParams], + root_resource: RestResourceBase, url: str, verb: rsrc_verb, data: Optional[dict[str, T_SupportedRESTFields]] = None, query_string: Optional[str] = None, - incoming_cookie: dict[str, str] = {}, - outgoing_cookie: dict[str, str] = {}, ) -> None: """class to handle a request context, that will be kept and updated while walking url parts @@ -118,27 +120,29 @@ class RestRequest(Generic[_T_RestRequestParams]): self.url: str self.verb: rsrc_verb self.data: dict - self.raw_headers: list[Any] + self._raw_headers: list[Any] = [] + self._client: tuple[str, int] = () self.headers: dict[str, None | str | dict[str, None | str]] = {"host": None, "cookie": {}} self._saved_url_params: dict self.ReqParams: _T_RestRequestParams = type_request_params() self.url_stack: list[str] self._saved_url_stack: list[str] self.url_stack_index: int - self.incoming_cookie: dict[str, str] = incoming_cookie - self.outgoing_cookie: dict[str, str] = outgoing_cookie + self.outgoing_cookie: dict[str, str] = {} self.user: ACL_target_user = ACL_target_user_Annonymous() self.groups: list[ACL_target_group] = [] self.result: Optional[str] = None + self._forced_status: Optional[int] = None + self.root_resource: RestResourceBase = root_resource # = or create a fresh one = if url is None or verb is None or data is None: - raise RuntimeError("url and verb and data must be set") + raise RestResourceException("url and verb and data must be set") self.url = url self.verb = verb if data != {} and not check_type(data, T_AllSupportedFields): - raise RuntimeError(f"Wrong data type received: {data}") + raise RestResourceHandlerException_BadRequest(f"Wrong data type received: {data}") self.data = data @@ -157,13 +161,34 @@ class RestRequest(Generic[_T_RestRequestParams]): self._saved_url_stack = self.url_stack.copy() self.url_stack_index = 0 + def set_resp_status(self, status: int) -> None: + self._forced_status = status + + def get_root_resource(self) -> RestResourceBase: + return self.root_resource + + def get_status(self) -> int: + if self._forced_status is not None: + return self._forced_status + + if self.verb in (rsrc_verb.POST, rsrc_verb.PUT): + return 201 + + return 200 + + def set_client(self, client: tuple[str, int]) -> None: + self._client = client + + def get_client(self) -> tuple[str, int]: + return self._client + def set_headers(self, headers: list[Any]) -> None: - self.raw_headers = headers - for elem in self.raw_headers: + self._raw_headers = headers + for elem in self._raw_headers: if elem[0] == b"host": self.headers["host"] = elem[1].decode("utf-8") - # elif elem[0] == b"user-agent": - # self.headers["user-agent"] = elem[1].decode("utf-8") + elif elem[0] == b"user-agent": + self.headers["user-agent"] = elem[1].decode("utf-8") elif elem[0] == b"cookie": self.headers["cookie"] = parse_dict_cookies(elem[1].decode("utf-8")) @@ -172,8 +197,16 @@ class RestRequest(Generic[_T_RestRequestParams]): return None return self.headers["cookie"][key] + def set_resp_cookie_value(self, key: str, value: str) -> None: + self.outgoing_cookie[ + key + ] = f"{value}; expires={self.root_resource.get_new_cookie_expiration_date().strftime('%a, %d %b %Y %H:%M:%S GMT')}; path=/; HttpOnly" + + def reset_resp_cookie(self, key: str) -> None: + self.outgoing_cookie[key] = "null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT" + def get_host(self) -> str: - print(self.headers["host"]) + return self.headers["host"] def set_result(self, result: str): self.result = result diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index bd42c32..b7c00c0 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -8,9 +8,11 @@ from typing import ( from abc import ABC import json +import pprint + from pydantic import BaseModel -from .helpers import _JSONEncoder +from .helpers import _JSONEncoder, forward_exception from .rest_types import rsrc_verb, _T_SupportedRESTFields from .rest_ACL import ( @@ -22,6 +24,17 @@ from .rest_ACL import ( ) from .rest_request import RestRequest +from .rest_exceptions import ( + RestResourceLoginException_InvalidSession, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_ClientChange, + RestResourceLoginException_InvalidCredentials, + RestResourceHandlerException_ResourceNotFound, + RestResourceHandlerException_MethodNotAllowed, + RestResourceHandlerException_BadRequest, + RestResourceHandlerException_Forbiden, + RestResourceException, +) if TYPE_CHECKING is True: from .rest_types import ( @@ -64,15 +77,15 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if acl.rule is ACL_rule.ALLOW: # print("ALLOWED (user)") return - raise RuntimeError(f"Not allowed access detected: {field}") + raise RestResourceHandlerException_Forbiden(f"Not allowed access detected: {field}") elif isinstance(acl.target, ACL_target_group): if isinstance(acl.target, ACL_target_group_Any) or any(_ for _ in groups if _.name == acl.target.name): if acl.rule is ACL_rule.ALLOW: # print("ALLOWED (group)") return - raise RuntimeError(f"Not allowed access detected: {field}") + raise RestResourceHandlerException_Forbiden(f"Not allowed access detected: {field}") else: - raise RuntimeError(f"Wrong ACL target type: {field}") + raise RestResourceException(f"Wrong ACL target type: {field}") # print("ALLOWED (Default)") def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None: @@ -89,7 +102,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if key in self.model_fields: self._check_acl(request.user, request.groups, rsrc_verb.PUT, key) else: - raise RuntimeError("Incompatible verb") + raise RestResourceException("Incompatible verb") def update(self, **new_data): for field, value in new_data.items(): @@ -119,27 +132,25 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if b"content-type" in scope["headers"]: assert scope["headers"][b"content-type"] == b"application/json" - # import pprint - - # print("----REC HEADER ---") - # pprint.pprint(scope["headers"]) + # pprint.pprint(scope) body = await self.read_body(receive) - verb = rsrc_verb[scope["method"]] request: RestRequest = self.process_request( - scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8"), scope["headers"] + scope["path"], + rsrc_verb[scope["method"]], + body.decode("utf-8"), + scope["query_string"].decode("utf-8"), + scope["client"], + scope["headers"], + True, ) - assert request != None - - status = 200 - if verb in (rsrc_verb.POST, rsrc_verb.PUT): - status = 201 + assert request is not None header_resp = { "type": "http.response.start", - "status": status, + "status": request.get_status(), "headers": [ [b"content-type", b"application/json"], ], @@ -148,8 +159,6 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): for name, value in request.outgoing_cookie.items(): header_resp["headers"].append(["Set-Cookie", f"{name}={value}"]) - # print("----SENT HEADER ---") - # pprint.pprint(header_resp) await send(header_resp) body = None @@ -172,7 +181,9 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): verb: rsrc_verb = rsrc_verb.GET, data_json: Optional[str] = None, query_string: Optional[str] = None, + client: Optional[tuple[str, int]] = None, headers: Optional[list[Any]] = None, + http_mode: bool = False, ) -> RestRequest: from .rest_resource_handler import ( ResourceHandler, @@ -188,22 +199,50 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): # preparing request & session request: RestRequest = ressource_handler.get_request() + assert request is not None + if headers is not None: request.set_headers(headers) + + if client is not None: + request.set_client(client) + + try: self._process_request_session(request) - # processing the verb - result = ressource_handler.process_verb() + result = ressource_handler.process_verb() - # print("OOO") - # print(type(self)._resp_cookies) - # print("OOO2") + if isinstance(result, RestResourceBase): + request.set_result(json.dumps(result.model_dump(mode="json"))) + elif result is not None: + request.set_result(json.dumps(result, cls=_JSONEncoder)) + else: + request.set_result("null") - if isinstance(result, RestResourceBase): - request.set_result(json.dumps(result.model_dump(mode="json"))) - elif result is not None: - request.set_result(json.dumps(result, cls=_JSONEncoder)) - else: - request.set_result("null") + except RestResourceHandlerException_ResourceNotFound as e: + request.set_resp_status(404) + forward_exception(e, not http_mode) + + except RestResourceHandlerException_MethodNotAllowed as e: + request.set_resp_status(405) + forward_exception(e, not http_mode) + + except RestResourceHandlerException_BadRequest as e: + request.set_resp_status(400) + forward_exception(e, not http_mode) + + except RestResourceHandlerException_Forbiden as e: + request.set_resp_status(403) + forward_exception(e, not http_mode) + + except ( + RestResourceLoginException_InvalidSession, + RestResourceLoginException_SessionTimeout, + RestResourceLoginException_ClientChange, + RestResourceLoginException_InvalidCredentials, + ) as e: + request.set_resp_status(401) + request.reset_resp_cookie("Authorization") + forward_exception(e, not http_mode) return request diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 77cfd42..3197c53 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -33,6 +33,14 @@ from .rest_request_opt import ( _T_RestRequestParams_PUT, ) +from .rest_exceptions import ( + RestResourceHandlerException, + RestResourceHandlerException_ResourceNotFound, + RestResourceHandlerException_MethodNotAllowed, + RestResourceHandlerException_BadRequest, + RestResourceHandlerException_Forbiden, +) + if TYPE_CHECKING is True: from .rest_types import T_T_DictKey, T_T_DictValues from .rest_request import RestRequest @@ -83,7 +91,6 @@ class ResourceHandler( self.next_handler: Optional[ResourceHandler] = None self.saved_url: list[str] = [] self.resource: _T_Resource = resource - self.root_resource: _T_Resource = resource if prev_handler is None else prev_handler.root_resource self.req: RestRequest if prev_handler is not None: self.prev_handler = prev_handler @@ -91,13 +98,13 @@ class ResourceHandler( self._request_factory.update_RestRequest(self.req) elif None in [url, verb]: - raise RuntimeError("if req not set, url,verb must be setted") + raise RestResourceHandlerException("if req not set, url,verb must be setted") else: if url is None or verb is None: - raise RuntimeError("url and verb must be set") + raise RestResourceHandlerException("url and verb must be set") if data is None: data = {} - self.req = self._request_factory.get_RestRequest(url, verb, data, query_string) + self.req = self._request_factory.get_RestRequest(resource, url, verb, data, query_string) # print(f"[TRACE] creating {type(self).__name__}() with url={self.req.get_url_stack()}") @@ -116,7 +123,7 @@ class ResourceHandler( if resource_handler_cls._check_resource_handler(resource, req): # print(f"[DEBUG] match ResourceHandler: {resource_handler_cls.__name__}") return resource_handler_cls - raise RuntimeError(f"Unsupported Resource Type {type(resource).__name__}") + raise RestResourceHandlerException(f"Unsupported Resource Type {type(resource).__name__}") @classmethod def register_resource_handler(cls, other_cls) -> None: @@ -187,7 +194,7 @@ class ResourceHandler( return next_resource_handler # in _find_resource context, only resource's real values can be retrieved - raise RuntimeError("Wrong request") + raise RestResourceHandlerException_ResourceNotFound() def _check_access_rights(self): pass @@ -210,7 +217,7 @@ class ResourceHandler( self._process_delete() return None - raise RuntimeError("Invalid Verb") + raise RestResourceHandlerException_BadRequest("Invalid Verb") def _process_get( self, @@ -231,16 +238,16 @@ class ResourceHandler( self._handle_process_delete(self.req.get_req_params()) def _handle_process_get(self, params: _T_RestRequestParams_GET) -> _T_Resource | list[T_DictKey]: - raise RuntimeError(f"GET method not implemented for {type(self).__name__}") + raise RestResourceHandlerException_MethodNotAllowed(f"GET method not implemented for {type(self).__name__}") def _handle_process_put(self, params: _T_RestRequestParams_PUT) -> None: - raise RuntimeError(f"PUT method not implemented for {type(self).__name__}") + raise RestResourceHandlerException_MethodNotAllowed(f"PUT method not implemented for {type(self).__name__}") def _handle_process_post(self, params: _T_RestRequestParams_POST) -> Optional[T_DictKey]: - raise RuntimeError(f"POST method not implemented for {type(self).__name__}") + raise RestResourceHandlerException_MethodNotAllowed(f"POST method not implemented for {type(self).__name__}") def _handle_process_delete(self, params: _T_RestRequestParams_DELETE) -> None: - raise RuntimeError(f"DELETE method not implemented for {type(self).__name__}") + raise RestResourceHandlerException_MethodNotAllowed(f"DELETE method not implemented for {type(self).__name__}") @ResourceHandler.register_resource_handler @@ -289,8 +296,7 @@ class ResourceHandler_dict( # print(f"{type(self).__name__}->_handle_process_delete()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") - if self.prev_handler is None: - raise RuntimeError("Wrong command") + assert self.prev_handler is not None dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] @@ -308,8 +314,7 @@ class ResourceHandler_dict( # print(f"{type(self).__name__}->_handle_process_post()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") - if self.prev_handler is None: - raise RuntimeError("Wrong command") + assert self.prev_handler is not None dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] dict_value_type: T_T_DictValues = cast(RestResourceBase, self.prev_handler.resource)._dict_value_type_[ @@ -341,7 +346,9 @@ class ResourceHandler_dict( _dict[_obj_primary_key] = _obj return _obj_primary_key - RuntimeError("Either the object needs defined primary key or the request must contain an API_key param to process this command") + raise RestResourceHandlerException_BadRequest( + "Either the object needs defined primary key or the request must contain an API_key param to process this command" + ) return None # for mypy.... @@ -381,8 +388,7 @@ class ResourceHandler_dict_elem( # print(f"{type(self).__name__}->_process_get()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") - if self.prev_handler is None: - raise RuntimeError("Wrong command") + assert self.prev_handler is not None dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(1)] @@ -401,8 +407,7 @@ class ResourceHandler_dict_elem( # instead of expected get_resource_origin(1) because we need to go backward # because self.req is another context that is not saved to improve performances - if self.prev_handler is None: - raise RuntimeError("Wrong command") + assert self.prev_handler is not None dict_key_type: T_T_DictKey = cast(RestResourceBase, self.prev_handler.resource)._dict_key_type_[self.req.get_resource_origin(2)] @@ -460,13 +465,13 @@ class ResourceHandler_RestResourceBase( # print(self.resource.exclude) if self.req.get_resource_origin(0) not in self.resource.model_fields: - raise RuntimeError(f"Unknown field access detected: {self.req.get_url_stack()}") + raise RestResourceHandlerException_ResourceNotFound(f"Unknown field access detected: {self.req.get_url_stack()}") self.resource.check_acl_field(self.req) if len(self.req.get_url_stack()) == 0: # destination reached if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET: - raise RuntimeError(f"Not allowed READ access detected: {self.req.get_url_stack()}") + raise RestResourceHandlerException_ResourceNotFound(f"Not allowed READ access detected: {self.req.get_url_stack()}") def _handle_process_get(self, params) -> RestResourceBase: # print(f"{type(self).__name__}->_process_get()") @@ -480,13 +485,13 @@ class ResourceHandler_RestResourceBase( if key in self.resource._plugins_: if issubclass(self.resource._plugins_[key], ResourcePlugin_field): plugin_field: ResourcePlugin_field = cast( - ResourcePlugin_field, self.resource._plugins_[key](self.req, self.root_resource) + ResourcePlugin_field, self.resource._plugins_[key](self.req, self.req.get_root_resource()) ) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_field_get(value, params)) elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_field: ResourcePlugin_field = cast( - ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource) + ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.req.get_root_resource()) ) value = getattr(self.resource, key) setattr(self.resource, key, plugin_field.handle_resource_get(value, params)) @@ -509,14 +514,14 @@ class ResourceHandler_RestResourceBase( if issubclass(self.resource._plugins_[key], ResourcePlugin_field): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.resource._plugins_[key](self.req, self.root_resource), + self.resource._plugins_[key](self.req, self.req.get_root_resource()), ) value = plugin_rsrc.handle_field_get(value, params) elif issubclass(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.resource._plugins_[key](self.req, self.root_resource), + self.resource._plugins_[key](self.req, self.req.get_root_resource()), ) value = plugin_rsrc.handle_resource_get(value, params) @@ -539,7 +544,7 @@ class ResourceHandler_RestResourceBase( if key in _new_resrc._plugins_: if issubclass(_new_resrc._plugins_[key], ResourcePlugin_field): plugin_field: ResourcePlugin_field = cast( - ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.root_resource) + ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.req.get_root_resource()) ) value = getattr(_new_resrc, key) setattr(_new_resrc, key, plugin_field.handle_field_put(value, params)) @@ -555,7 +560,7 @@ class ResourceHandler_RestResourceBase( if key in self.prev_handler.prev_handler.resource._plugins_: plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.prev_handler.prev_handler.resource._plugins_[key](self.req, self.root_resource), + self.prev_handler.prev_handler.resource._plugins_[key](self.req, self.req.get_root_resource()), ) _new_resrc = plugin_rsrc.handle_dict_elem_put(_new_resrc, params) # element is within a RestResourceBase @@ -564,7 +569,7 @@ class ResourceHandler_RestResourceBase( if key in self.prev_handler.resource._plugins_: plugin_rsrc: ResourcePlugin_RestResourceBase = cast( ResourcePlugin_RestResourceBase, - self.prev_handler.resource._plugins_[key](self.req, self.root_resource), + self.prev_handler.resource._plugins_[key](self.req, self.req.get_root_resource()), ) _new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params) @@ -584,7 +589,7 @@ class ResourceHandler_RestResourceBase( ): self.prev_handler._process_delete() else: - raise RuntimeError("cannot delete an element outside a dict") + raise RestResourceHandlerException_BadRequest("cannot delete an element outside a dict") @ResourceHandler.register_resource_handler @@ -615,7 +620,7 @@ class ResourceHandler_simple( if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, - self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource), + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.req.get_root_resource()), ) return plugin_simple.handle_field_get(self.resource, params) @@ -636,7 +641,7 @@ class ResourceHandler_simple( # print("PLUGIN FOUND") plugin_simple: ResourcePlugin_field = cast( ResourcePlugin_field, - self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource), + self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.req.get_root_resource()), ) # print(value) value = plugin_simple.handle_field_put(value, params) diff --git a/src/pyrestresource/rest_resource_plugin.py b/src/pyrestresource/rest_resource_plugin.py index 4e7bcea..6a97876 100644 --- a/src/pyrestresource/rest_resource_plugin.py +++ b/src/pyrestresource/rest_resource_plugin.py @@ -2,16 +2,17 @@ from __future__ import annotations from typing import Optional, Generic, TYPE_CHECKING from abc import abstractmethod, ABC +from datetime import datetime from .rest_types import ( _T_DictValues, _T_DictKey, TV_SupportedRESTFields, TV_RestResourceBase, - RestResourceException, ) from .rest_request import RestRequest + if TYPE_CHECKING is True: from .rest_resource import RestResourceBase from .rest_request_opt import ( @@ -27,14 +28,6 @@ if TYPE_CHECKING is True: ) -class RestResourcePluginException(RestResourceException): - pass - - -class RestResourcePluginException_InvalidPluginSignature(RestResourcePluginException): - pass - - class ResourcePlugin(ABC): def __init__(self, request: RestRequest, root_resource: RestResourceBase) -> None: self.__request: RestRequest = request @@ -46,16 +39,17 @@ class ResourcePlugin(ABC): def get_user_login(self) -> str: return self.__request.get_user().name - def getr_req_cookie_value(self, key: str) -> Optional[str]: - return self.__request.incoming_cookie[key] + def set_resp_cookie_value(self, key: str, value: str) -> None: + self.__request.set_resp_cookie_value(key, value) - def set_resp_cookie_value(self, key: str, value: str): - # print("AAA") - # print(name) - # print(value) - # print(self.cookies) - # print(type(self.cookies)) - self.__request.outgoing_cookie[key] = value + def reset_resp_cookie(self, key: str) -> None: + self.__request.reset_resp_cookie(key) + + def get_new_cookie_expiration_date(self) -> datetime: + return self.__root_resource.get_new_cookie_expiration_date() + + def set_resp_status(self, status: int) -> None: + self.__request.set_resp_status(status) class ResourcePlugin_field(ResourcePlugin, Generic[TV_SupportedRESTFields]): diff --git a/src/pyrestresource/rest_resource_rootpoint.py b/src/pyrestresource/rest_resource_rootpoint.py index a13e521..ac62b58 100644 --- a/src/pyrestresource/rest_resource_rootpoint.py +++ b/src/pyrestresource/rest_resource_rootpoint.py @@ -12,7 +12,6 @@ from .rest_resource_plugin import ( ResourcePlugin_field, ResourcePlugin_RestResourceBase, ResourcePlugin_dict, - RestResourcePluginException_InvalidPluginSignature, ) from .rest_resource_walker import ( RestResourceWalker_Root, @@ -26,6 +25,7 @@ from .rest_ACL import ( ACL_target_group_Any, ACL_rule, ) +from .rest_exceptions import RestResourcePluginException_InvalidPluginSignature, RestResourceModelException, RestResourceModelException_ACL if TYPE_CHECKING is True: pass @@ -37,9 +37,9 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): # checking compatibility if not get_origin(datatype[1]) is None: - raise RuntimeError("complex dict types are not supported (should create a RestResourceBase container)") + raise RestResourceModelException("complex dict types are not supported (should create a RestResourceBase container)") if not datatype[0] in _T_SupportedRESTFields: - raise RuntimeError(f"Unsupported Dict Field value type in class (key)") + raise RestResourceModelException(f"Unsupported Dict Field value type in class (key)") # preprocessing types / structure if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): @@ -69,10 +69,10 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): # print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] else: - raise RuntimeError("ACL must be a list()") + raise RestResourceModelException_ACL("ACL must be a list()") else: - raise RuntimeError("dict must be contained in a RestResourceBase") + raise RestResourceModelException("dict must be contained in a RestResourceBase") class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFields): @@ -96,7 +96,9 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi if "primary_key" in self.resource.json_schema_extra and self.resource.json_schema_extra["primary_key"] is True: if self.parent.annotation._primary_key_ is not None: - raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") + raise RestResourceModelException( + f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}" + ) self.parent.annotation._primary_key_ = self.resource_name self.parent.annotation._ACL_record_[self.resource_name] = [ ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY) @@ -114,10 +116,10 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi # print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] else: - raise RuntimeError("ACL must be a list()") + raise RestResourceModelException_ACL("ACL must be a list()") else: - raise RuntimeError("fields must be contained in a RestResourceBase") + raise RestResourceModelException("fields must be contained in a RestResourceBase") class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_RestResourceBase): @@ -153,7 +155,7 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ # print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") self.parent.annotation._ACL_record_[self.resource_name] += self.resource.json_schema_extra["ACL"] else: - raise RuntimeError("ACL must be a list()") + raise RestResourceModelException_ACL("ACL must be a list()") class RestResourceWalker_Root__tree_init(RestResourceWalker_Root): diff --git a/src/pyrestresource/rest_resource_walker.py b/src/pyrestresource/rest_resource_walker.py index a72a834..f3d14fe 100644 --- a/src/pyrestresource/rest_resource_walker.py +++ b/src/pyrestresource/rest_resource_walker.py @@ -15,6 +15,7 @@ from pydantic.fields import FieldInfo from .rest_types import _T_SupportedRESTFields from .rest_resource import RestResourceBase +from .rest_exceptions import RestResourceModelException if TYPE_CHECKING is True: from typing import Any, Optional @@ -59,7 +60,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): if _is_valid is True: return sub(resource_name, resource, parent, _anno, _optional, argument) - raise RuntimeError(f"Incompatible Field Found: {type(resource).__name__}") + raise RestResourceModelException(f"Incompatible Field Found: {type(resource).__name__}") return None def __init__( @@ -91,35 +92,10 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): self.optional = _optional if self.annotation is None: - raise RuntimeError("Only annotated types are allowed in RestResourceBase derived classes") + raise RestResourceModelException("Only annotated types are allowed in RestResourceBase derived classes") self.subdatatype = get_args(self.annotation) - """ - def info(self) -> None: - print(f"{type(self).__name__}->info()") - print("==========================") - print(f"resource_name: {self.resource_name}") - print(f"resource: {type(self.resource).__name__}") - print(f"resource: {self.resource}") - print(f"parent: {self.parent}") - print(f"annotation: {self.annotation}") - print(f"optional: {self.optional}") - print(f"subdatatype: {self.subdatatype}") - - # -> cannot do that on dicts - # if self.parent is not None: - # print(f"_model_dump_excluded_: {self.parent.annotation._model_dump_excluded_}") - - if False: - print("------ STACK ------") - _rsrc = self.parent - while _rsrc is not None: - print(f"{id(_rsrc.annotation)}:{_rsrc.annotation}") - _rsrc = _rsrc.parent - print("-------------------") - """ - @abstractmethod def get_future(self) -> Optional[RestResourceWalkerFutureResult]: return self.future_result @@ -163,7 +139,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): elif not isinstance(resource, FieldInfo) and issubclass(resource, RestResourceBase): _anno = resource else: - raise RuntimeError("Incompatible resource type") + raise RestResourceModelException("Incompatible resource type") _datatype = get_args(_anno) _optional: bool = False @@ -176,7 +152,7 @@ class RestResourceWalker_Sub(ABC, Generic[TV_RestResourceWalkerFutureResult]): _anno = _datatype[0] _optional = True else: - raise RuntimeError("Union is only allowed to describe Optional (e.g. Union[XXX,None])") + raise RestResourceModelException("Union is only allowed to describe Optional (e.g. Union[XXX,None])") return _anno, _optional @@ -277,5 +253,5 @@ class RestResourceWalker_Root: current_deep = current_deep + 1 return sub_walker_initial.chain_process_future() else: - raise RuntimeError("Invalid Rootpoint") + raise RestResourceModelException("Invalid Rootpoint") return None diff --git a/src/pyrestresource/rest_types.py b/src/pyrestresource/rest_types.py index 6e0a7f8..da34850 100644 --- a/src/pyrestresource/rest_types.py +++ b/src/pyrestresource/rest_types.py @@ -12,10 +12,6 @@ if TYPE_CHECKING is True: pass -class RestResourceException(Exception): - pass - - T_Gen_DictKeys: type = type({}.keys()) NoneType = type(None) @@ -63,8 +59,7 @@ TV_SupportedRESTFields = TypeVar( NoneType, ) -if get_origin(T_SupportedRESTFields) is not Union: - raise RuntimeError("wrong T_SupportedRESTFields (must be flat Union)") +assert get_origin(T_SupportedRESTFields) is Union TV_RestResourceBase = TypeVar("TV_RestResourceBase", bound="RestResourceBase") diff --git a/test/test_ACL.py b/test/test_ACL.py index 5e69df8..7847676 100644 --- a/test/test_ACL.py +++ b/test/test_ACL.py @@ -5,12 +5,8 @@ from pathlib import Path from typing import Optional from pydantic import Field - -print(__name__) -print(__package__) - - from src.pyrestresource import ( + RestResourceHandlerException_Forbiden, register_rest_rootpoint, RestResourceBase, rsrc_verb, @@ -85,11 +81,11 @@ class Test_RestAPI_ACL(unittest.TestCase): result = self.testapp.process_request("/resource_ro", rsrc_verb.GET) self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_ro/version_ro", rsrc_verb.PUT, '"6.6.6"') self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3") - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_ro", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}') self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3") @@ -107,7 +103,7 @@ class Test_RestAPI_ACL(unittest.TestCase): self.assertEqual(result.get_result(), "null") self.assertEqual(self.testapp.resource_with_secret.username, None) - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) self.assertEqual(self.testapp.resource_with_secret.secret, None) @@ -122,7 +118,7 @@ class Test_RestAPI_ACL(unittest.TestCase): self.assertEqual(result.get_result(), '"chacha"') self.assertEqual(self.testapp.resource_with_secret.username, "chacha") - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) self.assertEqual(self.testapp.resource_with_secret.secret, "123456") @@ -138,13 +134,13 @@ class Test_RestAPI_ACL(unittest.TestCase): self.assertEqual(result.get_result(), '"chacha"') self.assertEqual(self.testapp.resource_with_secret.username, "chacha") - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) result = self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.PUT, '"123456"') self.assertEqual(result.get_result(), "null") - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret/secret", rsrc_verb.GET) self.assertEqual(self.testapp.resource_with_secret.secret, "123456") @@ -160,23 +156,23 @@ class Test_RestAPI_ACL(unittest.TestCase): self.assertEqual(result.get_result(), "null") self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret_ACL/secret", rsrc_verb.GET) self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret_ACL", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) def test_subresource_ACL_field(self): - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret_ACL/username", rsrc_verb.PUT, '"chacha"') self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/resource_with_secret_ACL/secret", rsrc_verb.PUT, '"123456"') self.assertEqual(self.testapp.resource_with_secret_ACL.username, None) self.assertEqual(self.testapp.resource_with_secret_ACL.secret, None) diff --git a/test/test_rest_login.py b/test/test_rest_login.py index f3fb471..9107fbc 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -1,12 +1,10 @@ from __future__ import annotations import unittest -from unittest.mock import patch from os import chdir from pathlib import Path -from typing import Optional, Annotated, ClassVar +from typing import Optional, ClassVar from pydantic import Field -from uuid import UUID, uuid4 -from time import time, sleep +from time import sleep import uvicorn import socket import requests @@ -14,10 +12,6 @@ from contextlib import closing from multiprocessing import Process from requests.adapters import HTTPAdapter -print(__name__) -print(__package__) - - from src.pyrestresource import ( ACL_target_user, UserLogin, @@ -45,6 +39,7 @@ chdir(testdir_path.parent.resolve()) # to allow mock-ing, all the tested classes are in a function def init_classes(): user_test = UserLogin(username="TestUser", secret="123456") + user_test2 = UserLogin(username="TestUser2", secret="abcdef") class TestResource(RestResourceBase): test_field: Optional[str] = Field("ORIGIN_VALUE") @@ -57,10 +52,25 @@ def init_classes(): ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ], ) + test_field2: Optional[str] = Field( + "ORIGIN_VALUE", + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test2), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) + test_field_both: Optional[str] = Field( + "ORIGIN_VALUE", + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user.from_user_login(user_test2), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) @register_rest_rootpoint class RootApp(RestResourceBaseLogin): - _ar_user_login: ClassVar[list[UserLogin]] = [user_test] + _ar_user_login: ClassVar[list[UserLogin]] = [user_test, user_test2] test_resourceACL: TestResource = Field( TestResource(), ACL=[ @@ -92,6 +102,113 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): def setUp(self) -> None: chdir(testdir_path.parent.resolve()) + def test_login_two_users(self): + ip, port = find_free_port() + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # login + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + # unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field2", json="TEST SET VALUE") + self.assertEqual(response.status_code, 403) + + # not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field2", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "ORIGIN_VALUE") + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field_both", json="TEST SET VALUE 2") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field_both", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") + + # --------------------------------------- + # login 2 + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser2", "secret": "abcdef"}, + ) + self.assertEqual(response.status_code, 201) + + # unauthenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="A TEST SET VALUE") + self.assertEqual(response.status_code, 403) + + # not modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE") + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field2", json="A TEST SET VALUE") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field2", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "A TEST SET VALUE") + + # previous (modified) value + response = s.get( + f"http://{ip}:{port}/test_resource/test_field_both", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TEST SET VALUE 2") + + # authenticated write (to field) + response = s.put(f"http://{ip}:{port}/test_resource/test_field_both", json="A TEST SET VALUE 2") + self.assertEqual(response.status_code, 201) + + # modified + response = s.get( + f"http://{ip}:{port}/test_resource/test_field_both", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "A TEST SET VALUE 2") + + finally: + proc.terminate() + s.close() + def test_login(self): ip, port = find_free_port() proc = Process( @@ -146,6 +263,239 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): proc.terminate() s.close() + def test_change_host(self): + ip, port = find_free_port() + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s1 = requests.Session() + s1.mount("http://", HTTPAdapter(max_retries=0)) + s2 = requests.Session() + s2.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # s1 - read full login resource + response = s1.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s1 - read login username field + response = s1.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # s2 - read full login resource + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s2 - read login username field + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # login s1 + response = s1.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + + # s1 - read full login resource + response = s1.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "TestUser"}) + + # s1 - read login username field + response = s1.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TestUser") + + # s2 - read full login resource + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s2 - read login username field + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + # s2 -> spoof s1 token + s2.cookies.update(s1.cookies) + + # s2 - read full login resource + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s2.cookies.get_dict(), {}) + + # s2 - read full login resource (reseted cookie) + response = s2.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + + # s2 -> spoof s1 token + s2.cookies.update(s1.cookies) + + # s2 - read login username field + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s2.cookies.get_dict(), {}) + + # s2 - read full login resource (reseted cookie) + response = s2.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + + finally: + proc.terminate() + s1.close() + s2.close() + + def test_login_wrong_pwd(self): + ip, port = find_free_port() + proc = Process( + target=launch_server, + args=( + ip, + port, + ), + ) + proc.start() + sleep(1) + s = requests.Session() + s.mount("http://", HTTPAdapter(max_retries=0)) + + try: + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + self.assertDictEqual(s.cookies.get_dict(), {}) + + # --------------------------------------------------- + # login (wrong pwd) + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "abc"}, + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + self.assertDictEqual(s.cookies.get_dict(), {}) + + # --------------------------------------------------- + # login (ok pwd) + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "123456"}, + ) + self.assertEqual(response.status_code, 201) + self.assertTrue("Authorization" in response.cookies) + self.assertTrue("Authorization" in s.cookies.get_dict()) + self.assertTrue(s.cookies.get_dict()["Authorization"]) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "TestUser"}) + self.assertTrue("Authorization" in s.cookies.get_dict()) + self.assertTrue(s.cookies.get_dict()["Authorization"]) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "TestUser") + self.assertTrue("Authorization" in s.cookies.get_dict()) + self.assertTrue(s.cookies.get_dict()["Authorization"]) + + # --------------------------------------------------- + # login (wrong pwd, after success) + response = s.put( + f"http://{ip}:{port}/login", + json={"username": "TestUser", "secret": "abc"}, + ) + self.assertEqual(response.status_code, 401) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read full login resource + response = s.get( + f"http://{ip}:{port}/login", + ) + self.assertEqual(response.status_code, 200) + self.assertDictEqual(response.json(), {"username": "__ANNONYMOUS__"}) + self.assertDictEqual(s.cookies.get_dict(), {}) + + # read login username field + response = s.get( + f"http://{ip}:{port}/login/username", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), "__ANNONYMOUS__") + self.assertDictEqual(s.cookies.get_dict(), {}) + + finally: + proc.terminate() + s.close() + def test_access_resourceACL(self): ip, port = find_free_port() proc = Process( @@ -170,7 +520,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): # try unauthenticated write (to field) response = s.put(f"http://{ip}:{port}/test_resourceACL/test_field", json="TEST SET VALUE") - self.assertEqual(response.status_code, 500) + self.assertEqual(response.status_code, 403) # check not modified response = s.get( @@ -181,7 +531,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): # try unauthenticated write (to resource) response = s.put(f"http://{ip}:{port}/test_resourceACL", json={"test_field": "TEST SET VALUE"}) - self.assertEqual(response.status_code, 500) + self.assertEqual(response.status_code, 403) # check not modified response = s.get( @@ -247,7 +597,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): # try unauthenticated write (to field) response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE") - self.assertEqual(response.status_code, 500) + self.assertEqual(response.status_code, 403) # check not modified response = s.get( @@ -258,7 +608,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase): # try unauthenticated write (to resource) response = s.put(f"http://{ip}:{port}/test_resource", json={"test_field": "TEST SET VALUE"}) - self.assertEqual(response.status_code, 500) + self.assertEqual(response.status_code, 403) # check not modified response = s.get( diff --git a/test/test_rest_resource.py b/test/test_rest_resource.py index f3b7c9d..2fd3d97 100644 --- a/test/test_rest_resource.py +++ b/test/test_rest_resource.py @@ -14,6 +14,7 @@ print(__name__) print(__package__) from src.pyrestresource import ( + RestResourceHandlerException_Forbiden, register_rest_rootpoint, RestResourceBase, rsrc_verb, @@ -268,11 +269,11 @@ class Test_RestAPI_GET(unittest.TestCase): self.assertEqual(result.get_result(), '"chacha"') def test_get_dict_user_element__nested_value__forbiden(self): - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request("/users/8da57a3c-661f-11ee-8c99-0242ac120002/secret", rsrc_verb.GET) def test_get_dict_user_element__nested_value__forbiden2(self): - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request( "/users/8da57a3c-661f-11ee-8c99-0242ac120002/secret?API_nested=True", rsrc_verb.GET, @@ -302,7 +303,7 @@ class Test_RestAPI_PUT(unittest.TestCase): self.assertEqual(result.get_result(), '"chacha2"') def test_put_user_nested_value__forbiden(self): - with self.assertRaises(RuntimeError): # TODO: custom exception + with self.assertRaises(RestResourceHandlerException_Forbiden): # TODO: custom exception self.testapp.process_request( "/users/8da57a3c-661f-11ee-8c99-0242ac120002/uuid", rsrc_verb.PUT, diff --git a/test/test_rest_resource_plugins.py b/test/test_rest_resource_plugins.py index e797b13..358ee55 100644 --- a/test/test_rest_resource_plugins.py +++ b/test/test_rest_resource_plugins.py @@ -35,7 +35,6 @@ def init_classes(): class ResourcePlugin_Info(ResourcePlugin_RestResourceBase_default): def handle_resource_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get: - print("HOOK GET !!") return Info_get(version="65.45", api_version="98.321") class Info_get(RestResourceBase): @@ -95,9 +94,10 @@ class Test_RestAPI_Plugin_PUT(unittest.TestCase): self.testapp.process_request("/info_put/version", rsrc_verb.PUT, '"1.5.6"') result = self.testapp.process_request("/info_put", rsrc_verb.GET) - print(result.get_result()) + self.assertEqual(result.get_result(), '{"version": "42", "api_version": "0.0.2"}') + result = self.testapp.process_request("/info_put/version", rsrc_verb.GET) - print(result.get_result()) + self.assertEqual(result.get_result(), '"42"') def test_put_field_version_resourceplugin(self): diff --git a/test/test_rest_webserver.py b/test/test_rest_webserver.py index 9329214..5c86163 100644 --- a/test/test_rest_webserver.py +++ b/test/test_rest_webserver.py @@ -154,15 +154,6 @@ class Test_RestAPI_WebServer(unittest.TestCase): ["9b0381d4-65f6-11ee-8c99-0242ac120002"], ) - # Login in - """ - response = s.post( - f"http://{ip}:{port}/login", - params={"username": "test", "password": "test"}, - ) - self.assertEqual(response.status_code, 200) - """ - # Add a new one (with all values setted) response = s.post( f"http://{ip}:{port}/games",