From c3ff00e877ab69e62e8d18a35ed949f170fa2c6b Mon Sep 17 00:00:00 2001 From: cclecle Date: Fri, 3 Nov 2023 13:11:38 +0000 Subject: [PATCH] implement first ACL version (lack dict support) --- src/pyrestresource/__init__.py | 1 + src/pyrestresource/rest_ACL.py | 41 ++++++ src/pyrestresource/rest_request.py | 17 +++ src/pyrestresource/rest_resource.py | 146 +++++++++++++++----- src/pyrestresource/rest_resource_handler.py | 44 ++++-- src/pyrestresource/rest_types.py | 15 +- test/test_rest_login.py | 47 +++++-- 7 files changed, 250 insertions(+), 61 deletions(-) create mode 100644 src/pyrestresource/rest_ACL.py diff --git a/src/pyrestresource/__init__.py b/src/pyrestresource/__init__.py index 6ee712e..15d39df 100644 --- a/src/pyrestresource/__init__.py +++ b/src/pyrestresource/__init__.py @@ -53,3 +53,4 @@ from .rest_resource_plugin import ( ResourcePlugin_RestResourceBase_default, ResourcePlugin_dict_default, ) +from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule diff --git a/src/pyrestresource/rest_ACL.py b/src/pyrestresource/rest_ACL.py new file mode 100644 index 0000000..a75fa8b --- /dev/null +++ b/src/pyrestresource/rest_ACL.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from pydantic import BaseModel +from enum import Enum, auto + +from .rest_types import rsrc_verb + + +class ACL_target(BaseModel): + pass + + +class ACL_target_user(ACL_target): + name: str + + +class ACL_target_user_Annonymous(ACL_target): + name: str = "__ANNONYMOUS__" + + +class ACL_target_group(ACL_target): + name: str + + +class ACL_target_group_Annonymous(ACL_target): + name: str = "__ANNONYMOUS__" + + +class ACL_target_group_Any(ACL_target_group): + name: str = "__ANY__" + + +class ACL_rule(Enum): + ALLOW = auto() + DENY = auto() + + +class ACL_record(BaseModel): + verbs: list[rsrc_verb] + target: ACL_target + rule: ACL_rule diff --git a/src/pyrestresource/rest_request.py b/src/pyrestresource/rest_request.py index 23ff34e..80f223d 100644 --- a/src/pyrestresource/rest_request.py +++ b/src/pyrestresource/rest_request.py @@ -25,6 +25,8 @@ from .rest_request_opt import ( _T_RestRequestParams_PUT, ) +from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group, ACL_target_group_Annonymous + class RequestFactory( Generic[ @@ -118,6 +120,9 @@ class RestRequest(Generic[_T_RestRequestParams]): self.url_stack_index: int self.incoming_cookie: dict[str, str] = incoming_cookie self.outgoing_cookie: dict[str, str] = outgoing_cookie + self.user: ACL_target_user = ACL_target_user_Annonymous() + self.group: ACL_target_group = ACL_target_group_Annonymous() + self.result: Optional[str] = None # = or create a fresh one = if url is None or verb is None or data is None: @@ -141,6 +146,18 @@ class RestRequest(Generic[_T_RestRequestParams]): self._saved_url_stack = self.url_stack.copy() self.url_stack_index = 0 + def set_result(self, result: str): + self.result = result + + def get_result(self) -> Optional[str]: + return self.result + + def set_user(self, user: ACL_target_user): + self.user: ACL_target_user = user + + def set_group(self, group: ACL_target_group): + self.group: ACL_target_group = group + def update_ReqParams(self, type_request_params: type[_T_RestRequestParams]): self.ReqParams = type_request_params(**self._saved_url_params) diff --git a/src/pyrestresource/rest_resource.py b/src/pyrestresource/rest_resource.py index 4ba860c..8a70a0e 100644 --- a/src/pyrestresource/rest_resource.py +++ b/src/pyrestresource/rest_resource.py @@ -34,6 +34,16 @@ from .rest_resource_plugin import ( ResourcePlugin_dict, ) +from .rest_ACL import ( + ACL_record, + ACL_target_user, + ACL_target_group, + ACL_target_user_Annonymous, + ACL_target_group_Annonymous, + ACL_target_group_Any, + ACL_rule, +) + from .rest_resource_walker import ( RestResourceWalkerFutureResult, @@ -116,17 +126,26 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict): self.resource.exclude = True self.parent.resource.model_rebuild(force=True) + self.parent.annotation._ACL_record_[self.resource_name] = [] + if ( isinstance(self.resource, FieldInfo) and self.resource.json_schema_extra is not None and type(self.resource.json_schema_extra) is dict - and "plugin" in self.resource.json_schema_extra ): - plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"] - if not isinstance(plugin_dict, ResourcePlugin_dict): - raise RuntimeError("Wrong plugin signature provided") - self.parent.annotation._plugins_[self.resource_name] = plugin_dict - # print("ADD DICT PLUGIN") + if "plugin" in self.resource.json_schema_extra: + plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"] + if not isinstance(plugin_dict, ResourcePlugin_dict): + raise RuntimeError("Wrong plugin signature provided") + self.parent.annotation._plugins_[self.resource_name] = plugin_dict + # print("ADD DICT PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") else: raise RuntimeError("dict must be contained in a RestResourceBase") @@ -141,6 +160,9 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi # pprint.pprint(self.resource.json_schema_extra) # pprint.pprint(self.annotation) # pprint.pprint(self.resource.exclude) + + self.parent.annotation._ACL_record_[self.resource_name] = [] + if ( isinstance(self.resource, FieldInfo) and self.resource.json_schema_extra is not None @@ -153,13 +175,20 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}") self.parent.annotation._primary_key_ = self.resource_name - if "plugin" in self.resource.json_schema_extra and self.resource.json_schema_extra["plugin"]: + if "plugin" in self.resource.json_schema_extra: plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"] if not isinstance(plugin_field, ResourcePlugin_field): raise RuntimeError("Wrong plugin signature provided") self.parent.annotation._plugins_[self.resource_name] = plugin_field # print("ADD FIELD PLUGIN") + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") + else: raise RuntimeError("fields must be contained in a RestResourceBase") @@ -171,24 +200,33 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_ setattr(self.annotation, "_model_dump_excluded_", {}) setattr(self.annotation, "_primary_key_", None) setattr(self.annotation, "_plugins_", {}) + setattr(self.annotation, "_ACL_record_", {}) # preprocessing types / structure if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase): self.parent.annotation._model_dump_excluded_[self.resource_name] = True self.resource.exclude = True self.parent.resource.model_rebuild(force=True) + self.parent.annotation._ACL_record_[self.resource_name] = [] if ( isinstance(self.resource, FieldInfo) and self.resource.json_schema_extra is not None and type(self.resource.json_schema_extra) is dict - and "plugin" in self.resource.json_schema_extra ): - plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] - if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase): - raise RuntimeError("Wrong plugin signature provided") - self.parent.annotation._plugins_[self.resource_name] = plugin_resource - # print("ADD RESOURCE PLUGIN") + if "plugin" in self.resource.json_schema_extra: + plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] + if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase): + raise RuntimeError("Wrong plugin signature provided") + self.parent.annotation._plugins_[self.resource_name] = plugin_resource + # print("ADD RESOURCE PLUGIN") + + if "ACL" in self.resource.json_schema_extra: + if isinstance(self.resource.json_schema_extra["ACL"], list): + print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}") + self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"] + else: + raise RuntimeError("ACL must be a list()") class RestResourceWalker_Root__tree_init(RestResourceWalker_Root): @@ -213,10 +251,55 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): _plugins_: ClassVar[ dict[ str, - ResourcePlugin_field | ResourcePlugin_RestResourceBase | ResourcePlugin_dict, + list[ACL_record], ] ] = {} - _request: Optional[RestRequest] = None + _ACL_record_: ClassVar[ + dict[ + str, + ACL_record, + ] + ] = {} + + def _check_acl(self, user: ACL_target_user, group: ACL_target_group, verb: rsrc_verb, field: str): + print(f"evaluate self ACLs rule: {self._ACL_record_}") + if verb is rsrc_verb.GET and self.model_fields[field].exclude is True: + print("ALLOWED (excluded field)") + return + for acl in self._ACL_record_[field]: + print(f"evaluate ACL rule: {acl}") + if verb in acl.verbs: + if isinstance(acl.target, ACL_target_user): + if user == acl.target: + if acl.rule is ACL_rule.ALLOW: + print("ALLOWED (user)") + return + raise RuntimeError(f"Not allowed access detected: {field}") + elif isinstance(acl.target, ACL_target_group): + if group == acl.target or acl.target == ACL_target_group_Any(): + if acl.rule is ACL_rule.ALLOW: + print("ALLOWED (group)") + return + raise RuntimeError(f"Not allowed access detected: {field}") + else: + raise RuntimeError(f"Wrong ACL target type: {field}") + print("ALLOWED (Default)") + + def check_acl_access(self, request: RestRequest) -> None: + """Check ACL on requested field access""" + self._check_acl(request.user, request.group, request.get_verb(), request.get_resource_origin(0)) + + def check_acl_operation(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None: + """Check ACL on requested field operation (involving checking sub-fields)""" + if request.get_verb() is rsrc_verb.GET: + for key in self.model_fields.keys(): + self._check_acl(request.user, request.group, rsrc_verb.GET, key) + elif request.get_verb() is rsrc_verb.PUT: + for key in new_data.keys(): + if key in self.model_fields: + self._check_acl(request.user, request.group, rsrc_verb.PUT, key) + else: + raise RuntimeError("Incompatible verb") def update(self, **new_data): for field, value in new_data.items(): @@ -251,12 +334,11 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): body = await self.read_body(receive) verb = rsrc_verb[scope["method"]] - self._request = None - result = self.process_request( + request: RestRequest = self.process_request( scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8") ) - assert self._request != None + assert request != None status = 200 if verb in (rsrc_verb.POST, rsrc_verb.PUT): @@ -270,7 +352,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): ], } - for name, value in self._request.outgoing_cookie.items(): + for name, value in request.outgoing_cookie.items(): header_resp["headers"].append(["Set-Cookie", f"{name}={value}"]) # print("----SENT HEADER ---") @@ -278,8 +360,8 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): await send(header_resp) body = None - if result: - body = result.encode("utf-8") + if request.get_result(): + body = request.get_result().encode("utf-8") await send( { @@ -294,7 +376,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): verb: rsrc_verb = rsrc_verb.GET, data_json: Optional[str] = None, query_string: Optional[str] = None, - ) -> Optional[str]: + ) -> RestRequest: from .rest_resource_handler import ( ResourceHandler, ResourceHandler_RestResourceBase, @@ -304,22 +386,20 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True): if data_json: data = json.loads(data_json) - ressource: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) - self._request = ressource.get_request() + ressource_handler: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) - result = ressource.process_verb() + request: RestRequest = ressource_handler.get_request() + assert request != None + + result = ressource_handler.process_verb() # print("OOO") # print(type(self)._resp_cookies) # print("OOO2") if isinstance(result, RestResourceBase): - # exclude: Optional[dict[str, bool]] = None - # raw_exclude = RestResourceWalker_Root__tree_exclude(result).process() - # exclude = next(iter(raw_exclude.values())) - # return json.dumps(result.model_dump(mode="json", exclude=exclude)) - return json.dumps(result.model_dump(mode="json")) + request.set_result(json.dumps(result.model_dump(mode="json"))) + elif result is not None: + request.set_result(json.dumps(result, cls=_JSONEncoder)) - if result is not None: - return json.dumps(result, cls=_JSONEncoder) - return None + return request diff --git a/src/pyrestresource/rest_resource_handler.py b/src/pyrestresource/rest_resource_handler.py index 218f412..5996cc6 100644 --- a/src/pyrestresource/rest_resource_handler.py +++ b/src/pyrestresource/rest_resource_handler.py @@ -19,6 +19,15 @@ from .rest_resource_plugin import ( ResourcePlugin_RestResourceBase, ) +from .rest_ACL import ( + ACL_target_user, + ACL_target_group, + ACL_target_user_Annonymous, + ACL_target_group_Annonymous, + ACL_target_group_Any, + ACL_rule, +) + from .rest_request_opt import ( RestRequestParams_POST, RestRequestParams_DELETE, @@ -174,7 +183,7 @@ class ResourceHandler( # reveal_type(next_resource) _next_resource = cast(_T_Resource, next_resource) # reveal_type(_next_resource) - # print(f"[DEBUG] next_resource = {type(next_resource).__name__}") + print(f"[DEBUG] next_resource = {type(next_resource).__name__}") if ( isinstance(_next_resource, RestResourceBase) @@ -194,7 +203,7 @@ class ResourceHandler( self.next_handler = next_resource_handler return next_resource_handler - # in the context of _find_resource, only resource real values can be retrieved + # in _find_resource context, only resource's real values can be retrieved raise RuntimeError("Wrong request") def _check_access_rights(self): @@ -455,28 +464,43 @@ class ResourceHandler_RestResourceBase( def _check_access_rights(self) -> None: super()._check_access_rights() - # print(f"{type(self).__name__}->_check_access_rights()") + print(f"{type(self).__name__}->_check_access_rights()") if self.req.get_resource_origin(0) == "/": return - # print("======") - # print(self.req.get_resource_origin(0)) + print("==================") + print(self.req.get_resource_origin(0)) # print(len(self.req.get_url_stack())) # print(self.resource._model_dump_excluded_) # print(type(self.resource)) # print(self.resource.exclude) if self.req.get_resource_origin(0) not in self.resource.model_fields: - raise RuntimeError(f"Unknown or not allowed field access detected: {self.req.get_url_stack()}") + raise RuntimeError(f"Unknown field access detected: {self.req.get_url_stack()}") + + self.resource.check_acl_access(self.req) + + if len(self.req.get_url_stack()) == 0: # destination reached + if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET: + raise RuntimeError(f"Not allowed READ access detected: {self.req.get_url_stack()}") + """ # not sure init_var has the expected behavior (read_only) + if self.resource.model_fields[self.req.get_resource_origin(0)].init_var is True and self.req.get_verb() in [ + rsrc_verb.POST, + rsrc_verb.PUT, + rsrc_verb.DELETE, + ]: + raise RuntimeError(f"Not allowed WRITE access detected: {self.req.get_url_stack()}") + """ def _handle_process_get(self, params) -> RestResourceBase: # print(f"{type(self).__name__}->_process_get()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") - # CASE 1: no more item in url_stack => we reached the endpoint + # CASE 1: no more item in url_stack => we reached the endpoint (operation) # So we are in a RestResourceBase instance and must return the content if len(self.req.get_url_stack()) == 0: + self.resource.check_acl_operation(self.req) for key, attr in self.resource.model_fields.items(): if key in self.resource._plugins_: if isinstance(self.resource._plugins_[key], ResourcePlugin_field): @@ -492,12 +516,12 @@ class ResourceHandler_RestResourceBase( # print(result) return self.resource - # CASE 2: specific case for root Node + # CASE 2: specific (operation) case for root Node # TODO: this must probably be merged with the previous bloc if self.req.get_resource_origin(0) == "/": return self.resource - # CASE 3: in between + # CASE 3: in between (access) value = getattr(self.resource, self.req.get_resource_origin(0)) key = self.req.get_resource_origin(0) @@ -522,6 +546,8 @@ class ResourceHandler_RestResourceBase( # print(f"{type(self).__name__}->_process_put()") # print(f"{type(self).__name__}->resource = {type(self.resource).__name__}") + self.resource.check_acl_operation(self.req, self.req.get_data()) + # creating a copy of the current resource _new_resrc = self.resource.copy() # updating values based on nex data diff --git a/src/pyrestresource/rest_types.py b/src/pyrestresource/rest_types.py index 17a95bd..55e7c12 100644 --- a/src/pyrestresource/rest_types.py +++ b/src/pyrestresource/rest_types.py @@ -39,10 +39,9 @@ _T_SupportedRESTFields = [ Path, IPv4Address, IPv4Network, + NoneType, ] -T_SupportedRESTFields = Union[ - UUID, str, int, float, bool, bytes, datetime, Path, IPv4Address, IPv4Network -] +T_SupportedRESTFields = Union[UUID, str, int, float, bool, bytes, datetime, Path, IPv4Address, IPv4Network, NoneType] TV_SupportedRESTFields = TypeVar( "TV_SupportedRESTFields", UUID, @@ -55,6 +54,7 @@ TV_SupportedRESTFields = TypeVar( Path, IPv4Address, IPv4Network, + NoneType, ) if get_origin(T_SupportedRESTFields) is not Union: @@ -68,12 +68,8 @@ T_FieldValue = Union[T_SupportedRESTFields, "RestResourceBase"] T_ListIndex = NewType("T_ListIndex", int) T_ListSize = NewType("T_ListSize", int) -T_DictKey = Union[ - UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network -] # datetime is removed because non-hashable -_T_DictKey = TypeVar( - "_T_DictKey", UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network -) +T_DictKey = Union[UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network] # datetime is removed because non-hashable +_T_DictKey = TypeVar("_T_DictKey", UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network) T_T_DictKey = type[T_DictKey] @@ -92,6 +88,7 @@ _T_DictValues = TypeVar( IPv4Address, IPv4Network, "RestResourceBase", + NoneType, ) T_T_FieldValue = type(T_FieldValue) diff --git a/test/test_rest_login.py b/test/test_rest_login.py index 2fdf8f1..32d2b1f 100644 --- a/test/test_rest_login.py +++ b/test/test_rest_login.py @@ -33,6 +33,7 @@ from src.pyrestresource import ( ResourcePlugin_field_default, ResourcePlugin_RestResourceBase_default, ) +from src.pyrestresource import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule from pprint import pprint testdir_path = Path(__file__).parent.resolve() @@ -76,16 +77,19 @@ def init_classes(): return resource class Login(RestResourceBase): - username: Optional[str] = Field(None, exclude=True) - # username: Optional[str] = Field(None) - secret: Optional[str] = Field(None, exclude=True) + username: Optional[str] = Field(None) + secret: Optional[str] = Field( + None, + exclude=True, + ACL=[ + ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), + ACL_record(verbs=[rsrc_verb.GET, rsrc_verb.DELETE, rsrc_verb.POST], target=ACL_target_group_Any(), rule=ACL_rule.DENY), + ], + ) @register_rest_rootpoint class RootApp(RestResourceBase): - login: Login = Field( - default=Login(), - plugin=ResourcePlugin_Login, - ) + login: Login = Field(default=Login(), plugin=ResourcePlugin_Login) # this add the classes to globals to allow using them later on # => this is only for uinit-testing purpose and is not needed in real use @@ -115,14 +119,37 @@ class Test_RestAPI_LOGIN(unittest.TestCase): self.testapp = RootApp() def test_login(self): + """ result = self.testapp.process_request("/login", rsrc_verb.GET) - print(result) + print("*****************") + print(result.get_result()) + + result = self.testapp.process_request("/login/username", rsrc_verb.GET) + print("*****************") + print(result.get_result()) + + # result = self.testapp.process_request("/login/secret", rsrc_verb.GET) + # print("*****************") + # print(result.get_result()) + """ result = self.testapp.process_request("/login", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}') - print(result) + print("*****************") + print(result.get_result()) + """ result = self.testapp.process_request("/login", rsrc_verb.GET) - print(result) + print("*****************") + print(result.get_result()) + + result = self.testapp.process_request("/login/username", rsrc_verb.GET) + print("*****************") + print(result.get_result()) + + # result = self.testapp.process_request("/login/secret", rsrc_verb.GET) + # print("*****************") + # print(result.get_result()) + """ class Test_RestAPI_LOGIN_Web(unittest.TestCase):