implement first ACL version (lack dict support)

This commit is contained in:
cclecle
2023-11-03 13:11:38 +00:00
parent 2251b1d5e9
commit c3ff00e877
7 changed files with 250 additions and 61 deletions

View File

@@ -53,3 +53,4 @@ from .rest_resource_plugin import (
ResourcePlugin_RestResourceBase_default,
ResourcePlugin_dict_default,
)
from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
from pydantic import BaseModel
from enum import Enum, auto
from .rest_types import rsrc_verb
class ACL_target(BaseModel):
pass
class ACL_target_user(ACL_target):
name: str
class ACL_target_user_Annonymous(ACL_target):
name: str = "__ANNONYMOUS__"
class ACL_target_group(ACL_target):
name: str
class ACL_target_group_Annonymous(ACL_target):
name: str = "__ANNONYMOUS__"
class ACL_target_group_Any(ACL_target_group):
name: str = "__ANY__"
class ACL_rule(Enum):
ALLOW = auto()
DENY = auto()
class ACL_record(BaseModel):
verbs: list[rsrc_verb]
target: ACL_target
rule: ACL_rule

View File

@@ -25,6 +25,8 @@ from .rest_request_opt import (
_T_RestRequestParams_PUT,
)
from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group, ACL_target_group_Annonymous
class RequestFactory(
Generic[
@@ -118,6 +120,9 @@ class RestRequest(Generic[_T_RestRequestParams]):
self.url_stack_index: int
self.incoming_cookie: dict[str, str] = incoming_cookie
self.outgoing_cookie: dict[str, str] = outgoing_cookie
self.user: ACL_target_user = ACL_target_user_Annonymous()
self.group: ACL_target_group = ACL_target_group_Annonymous()
self.result: Optional[str] = None
# = or create a fresh one =
if url is None or verb is None or data is None:
@@ -141,6 +146,18 @@ class RestRequest(Generic[_T_RestRequestParams]):
self._saved_url_stack = self.url_stack.copy()
self.url_stack_index = 0
def set_result(self, result: str):
self.result = result
def get_result(self) -> Optional[str]:
return self.result
def set_user(self, user: ACL_target_user):
self.user: ACL_target_user = user
def set_group(self, group: ACL_target_group):
self.group: ACL_target_group = group
def update_ReqParams(self, type_request_params: type[_T_RestRequestParams]):
self.ReqParams = type_request_params(**self._saved_url_params)

View File

@@ -34,6 +34,16 @@ from .rest_resource_plugin import (
ResourcePlugin_dict,
)
from .rest_ACL import (
ACL_record,
ACL_target_user,
ACL_target_group,
ACL_target_user_Annonymous,
ACL_target_group_Annonymous,
ACL_target_group_Any,
ACL_rule,
)
from .rest_resource_walker import (
RestResourceWalkerFutureResult,
@@ -116,17 +126,26 @@ class RestResourceWalker_Sub_T_Dict__tree_init(RestResourceWalker_Sub_T_Dict):
self.resource.exclude = True
self.parent.resource.model_rebuild(force=True)
self.parent.annotation._ACL_record_[self.resource_name] = []
if (
isinstance(self.resource, FieldInfo)
and self.resource.json_schema_extra is not None
and type(self.resource.json_schema_extra) is dict
and "plugin" in self.resource.json_schema_extra
):
plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"]
if not isinstance(plugin_dict, ResourcePlugin_dict):
raise RuntimeError("Wrong plugin signature provided")
self.parent.annotation._plugins_[self.resource_name] = plugin_dict
# print("ADD DICT PLUGIN")
if "plugin" in self.resource.json_schema_extra:
plugin_dict: ResourcePlugin_dict = self.resource.json_schema_extra["plugin"]
if not isinstance(plugin_dict, ResourcePlugin_dict):
raise RuntimeError("Wrong plugin signature provided")
self.parent.annotation._plugins_[self.resource_name] = plugin_dict
# print("ADD DICT PLUGIN")
if "ACL" in self.resource.json_schema_extra:
if isinstance(self.resource.json_schema_extra["ACL"], list):
print(f"found ACL (Dict): {self.resource.json_schema_extra['ACL']}")
self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"]
else:
raise RuntimeError("ACL must be a list()")
else:
raise RuntimeError("dict must be contained in a RestResourceBase")
@@ -141,6 +160,9 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi
# pprint.pprint(self.resource.json_schema_extra)
# pprint.pprint(self.annotation)
# pprint.pprint(self.resource.exclude)
self.parent.annotation._ACL_record_[self.resource_name] = []
if (
isinstance(self.resource, FieldInfo)
and self.resource.json_schema_extra is not None
@@ -153,13 +175,20 @@ class RestResourceWalker_Sub_RestFields__tree_init(RestResourceWalker_Sub_RestFi
raise RuntimeError(f"Only one primary key is allowed {self.parent.resource_name}.{self.resource_name}")
self.parent.annotation._primary_key_ = self.resource_name
if "plugin" in self.resource.json_schema_extra and self.resource.json_schema_extra["plugin"]:
if "plugin" in self.resource.json_schema_extra:
plugin_field: ResourcePlugin_field = self.resource.json_schema_extra["plugin"]
if not isinstance(plugin_field, ResourcePlugin_field):
raise RuntimeError("Wrong plugin signature provided")
self.parent.annotation._plugins_[self.resource_name] = plugin_field
# print("ADD FIELD PLUGIN")
if "ACL" in self.resource.json_schema_extra:
if isinstance(self.resource.json_schema_extra["ACL"], list):
print(f"found ACL (Field): {self.resource.json_schema_extra['ACL']}")
self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"]
else:
raise RuntimeError("ACL must be a list()")
else:
raise RuntimeError("fields must be contained in a RestResourceBase")
@@ -171,24 +200,33 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_
setattr(self.annotation, "_model_dump_excluded_", {})
setattr(self.annotation, "_primary_key_", None)
setattr(self.annotation, "_plugins_", {})
setattr(self.annotation, "_ACL_record_", {})
# preprocessing types / structure
if self.parent is not None and isinstance(self.parent, RestResourceWalker_Sub_RestResourceBase):
self.parent.annotation._model_dump_excluded_[self.resource_name] = True
self.resource.exclude = True
self.parent.resource.model_rebuild(force=True)
self.parent.annotation._ACL_record_[self.resource_name] = []
if (
isinstance(self.resource, FieldInfo)
and self.resource.json_schema_extra is not None
and type(self.resource.json_schema_extra) is dict
and "plugin" in self.resource.json_schema_extra
):
plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"]
if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase):
raise RuntimeError("Wrong plugin signature provided")
self.parent.annotation._plugins_[self.resource_name] = plugin_resource
# print("ADD RESOURCE PLUGIN")
if "plugin" in self.resource.json_schema_extra:
plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"]
if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase):
raise RuntimeError("Wrong plugin signature provided")
self.parent.annotation._plugins_[self.resource_name] = plugin_resource
# print("ADD RESOURCE PLUGIN")
if "ACL" in self.resource.json_schema_extra:
if isinstance(self.resource.json_schema_extra["ACL"], list):
print(f"found ACL (Resource): {self.resource.json_schema_extra['ACL']}")
self.parent.annotation._ACL_record_[self.resource_name] = self.resource.json_schema_extra["ACL"]
else:
raise RuntimeError("ACL must be a list()")
class RestResourceWalker_Root__tree_init(RestResourceWalker_Root):
@@ -213,10 +251,55 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
_plugins_: ClassVar[
dict[
str,
ResourcePlugin_field | ResourcePlugin_RestResourceBase | ResourcePlugin_dict,
list[ACL_record],
]
] = {}
_request: Optional[RestRequest] = None
_ACL_record_: ClassVar[
dict[
str,
ACL_record,
]
] = {}
def _check_acl(self, user: ACL_target_user, group: ACL_target_group, verb: rsrc_verb, field: str):
print(f"evaluate self ACLs rule: {self._ACL_record_}")
if verb is rsrc_verb.GET and self.model_fields[field].exclude is True:
print("ALLOWED (excluded field)")
return
for acl in self._ACL_record_[field]:
print(f"evaluate ACL rule: {acl}")
if verb in acl.verbs:
if isinstance(acl.target, ACL_target_user):
if user == acl.target:
if acl.rule is ACL_rule.ALLOW:
print("ALLOWED (user)")
return
raise RuntimeError(f"Not allowed access detected: {field}")
elif isinstance(acl.target, ACL_target_group):
if group == acl.target or acl.target == ACL_target_group_Any():
if acl.rule is ACL_rule.ALLOW:
print("ALLOWED (group)")
return
raise RuntimeError(f"Not allowed access detected: {field}")
else:
raise RuntimeError(f"Wrong ACL target type: {field}")
print("ALLOWED (Default)")
def check_acl_access(self, request: RestRequest) -> None:
"""Check ACL on requested field access"""
self._check_acl(request.user, request.group, request.get_verb(), request.get_resource_origin(0))
def check_acl_operation(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None:
"""Check ACL on requested field operation (involving checking sub-fields)"""
if request.get_verb() is rsrc_verb.GET:
for key in self.model_fields.keys():
self._check_acl(request.user, request.group, rsrc_verb.GET, key)
elif request.get_verb() is rsrc_verb.PUT:
for key in new_data.keys():
if key in self.model_fields:
self._check_acl(request.user, request.group, rsrc_verb.PUT, key)
else:
raise RuntimeError("Incompatible verb")
def update(self, **new_data):
for field, value in new_data.items():
@@ -251,12 +334,11 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
body = await self.read_body(receive)
verb = rsrc_verb[scope["method"]]
self._request = None
result = self.process_request(
request: RestRequest = self.process_request(
scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8")
)
assert self._request != None
assert request != None
status = 200
if verb in (rsrc_verb.POST, rsrc_verb.PUT):
@@ -270,7 +352,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
],
}
for name, value in self._request.outgoing_cookie.items():
for name, value in request.outgoing_cookie.items():
header_resp["headers"].append(["Set-Cookie", f"{name}={value}"])
# print("----SENT HEADER ---")
@@ -278,8 +360,8 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
await send(header_resp)
body = None
if result:
body = result.encode("utf-8")
if request.get_result():
body = request.get_result().encode("utf-8")
await send(
{
@@ -294,7 +376,7 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
verb: rsrc_verb = rsrc_verb.GET,
data_json: Optional[str] = None,
query_string: Optional[str] = None,
) -> Optional[str]:
) -> RestRequest:
from .rest_resource_handler import (
ResourceHandler,
ResourceHandler_RestResourceBase,
@@ -304,22 +386,20 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
if data_json:
data = json.loads(data_json)
ressource: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string)
self._request = ressource.get_request()
ressource_handler: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string)
result = ressource.process_verb()
request: RestRequest = ressource_handler.get_request()
assert request != None
result = ressource_handler.process_verb()
# print("OOO")
# print(type(self)._resp_cookies)
# print("OOO2")
if isinstance(result, RestResourceBase):
# exclude: Optional[dict[str, bool]] = None
# raw_exclude = RestResourceWalker_Root__tree_exclude(result).process()
# exclude = next(iter(raw_exclude.values()))
# return json.dumps(result.model_dump(mode="json", exclude=exclude))
return json.dumps(result.model_dump(mode="json"))
request.set_result(json.dumps(result.model_dump(mode="json")))
elif result is not None:
request.set_result(json.dumps(result, cls=_JSONEncoder))
if result is not None:
return json.dumps(result, cls=_JSONEncoder)
return None
return request

View File

@@ -19,6 +19,15 @@ from .rest_resource_plugin import (
ResourcePlugin_RestResourceBase,
)
from .rest_ACL import (
ACL_target_user,
ACL_target_group,
ACL_target_user_Annonymous,
ACL_target_group_Annonymous,
ACL_target_group_Any,
ACL_rule,
)
from .rest_request_opt import (
RestRequestParams_POST,
RestRequestParams_DELETE,
@@ -174,7 +183,7 @@ class ResourceHandler(
# reveal_type(next_resource)
_next_resource = cast(_T_Resource, next_resource)
# reveal_type(_next_resource)
# print(f"[DEBUG] next_resource = {type(next_resource).__name__}")
print(f"[DEBUG] next_resource = {type(next_resource).__name__}")
if (
isinstance(_next_resource, RestResourceBase)
@@ -194,7 +203,7 @@ class ResourceHandler(
self.next_handler = next_resource_handler
return next_resource_handler
# in the context of _find_resource, only resource real values can be retrieved
# in _find_resource context, only resource's real values can be retrieved
raise RuntimeError("Wrong request")
def _check_access_rights(self):
@@ -455,28 +464,43 @@ class ResourceHandler_RestResourceBase(
def _check_access_rights(self) -> None:
super()._check_access_rights()
# print(f"{type(self).__name__}->_check_access_rights()")
print(f"{type(self).__name__}->_check_access_rights()")
if self.req.get_resource_origin(0) == "/":
return
# print("======")
# print(self.req.get_resource_origin(0))
print("==================")
print(self.req.get_resource_origin(0))
# print(len(self.req.get_url_stack()))
# print(self.resource._model_dump_excluded_)
# print(type(self.resource))
# print(self.resource.exclude)
if self.req.get_resource_origin(0) not in self.resource.model_fields:
raise RuntimeError(f"Unknown or not allowed field access detected: {self.req.get_url_stack()}")
raise RuntimeError(f"Unknown field access detected: {self.req.get_url_stack()}")
self.resource.check_acl_access(self.req)
if len(self.req.get_url_stack()) == 0: # destination reached
if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET:
raise RuntimeError(f"Not allowed READ access detected: {self.req.get_url_stack()}")
""" # not sure init_var has the expected behavior (read_only)
if self.resource.model_fields[self.req.get_resource_origin(0)].init_var is True and self.req.get_verb() in [
rsrc_verb.POST,
rsrc_verb.PUT,
rsrc_verb.DELETE,
]:
raise RuntimeError(f"Not allowed WRITE access detected: {self.req.get_url_stack()}")
"""
def _handle_process_get(self, params) -> RestResourceBase:
# print(f"{type(self).__name__}->_process_get()")
# print(f"{type(self).__name__}->resource = {type(self.resource).__name__}")
# CASE 1: no more item in url_stack => we reached the endpoint
# CASE 1: no more item in url_stack => we reached the endpoint (operation)
# So we are in a RestResourceBase instance and must return the content
if len(self.req.get_url_stack()) == 0:
self.resource.check_acl_operation(self.req)
for key, attr in self.resource.model_fields.items():
if key in self.resource._plugins_:
if isinstance(self.resource._plugins_[key], ResourcePlugin_field):
@@ -492,12 +516,12 @@ class ResourceHandler_RestResourceBase(
# print(result)
return self.resource
# CASE 2: specific case for root Node
# CASE 2: specific (operation) case for root Node
# TODO: this must probably be merged with the previous bloc
if self.req.get_resource_origin(0) == "/":
return self.resource
# CASE 3: in between
# CASE 3: in between (access)
value = getattr(self.resource, self.req.get_resource_origin(0))
key = self.req.get_resource_origin(0)
@@ -522,6 +546,8 @@ class ResourceHandler_RestResourceBase(
# print(f"{type(self).__name__}->_process_put()")
# print(f"{type(self).__name__}->resource = {type(self.resource).__name__}")
self.resource.check_acl_operation(self.req, self.req.get_data())
# creating a copy of the current resource
_new_resrc = self.resource.copy()
# updating values based on nex data

View File

@@ -39,10 +39,9 @@ _T_SupportedRESTFields = [
Path,
IPv4Address,
IPv4Network,
NoneType,
]
T_SupportedRESTFields = Union[
UUID, str, int, float, bool, bytes, datetime, Path, IPv4Address, IPv4Network
]
T_SupportedRESTFields = Union[UUID, str, int, float, bool, bytes, datetime, Path, IPv4Address, IPv4Network, NoneType]
TV_SupportedRESTFields = TypeVar(
"TV_SupportedRESTFields",
UUID,
@@ -55,6 +54,7 @@ TV_SupportedRESTFields = TypeVar(
Path,
IPv4Address,
IPv4Network,
NoneType,
)
if get_origin(T_SupportedRESTFields) is not Union:
@@ -68,12 +68,8 @@ T_FieldValue = Union[T_SupportedRESTFields, "RestResourceBase"]
T_ListIndex = NewType("T_ListIndex", int)
T_ListSize = NewType("T_ListSize", int)
T_DictKey = Union[
UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network
] # datetime is removed because non-hashable
_T_DictKey = TypeVar(
"_T_DictKey", UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network
)
T_DictKey = Union[UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network] # datetime is removed because non-hashable
_T_DictKey = TypeVar("_T_DictKey", UUID, str, int, float, bool, bytes, Path, IPv4Address, IPv4Network)
T_T_DictKey = type[T_DictKey]
@@ -92,6 +88,7 @@ _T_DictValues = TypeVar(
IPv4Address,
IPv4Network,
"RestResourceBase",
NoneType,
)
T_T_FieldValue = type(T_FieldValue)

View File

@@ -33,6 +33,7 @@ from src.pyrestresource import (
ResourcePlugin_field_default,
ResourcePlugin_RestResourceBase_default,
)
from src.pyrestresource import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule
from pprint import pprint
testdir_path = Path(__file__).parent.resolve()
@@ -76,16 +77,19 @@ def init_classes():
return resource
class Login(RestResourceBase):
username: Optional[str] = Field(None, exclude=True)
# username: Optional[str] = Field(None)
secret: Optional[str] = Field(None, exclude=True)
username: Optional[str] = Field(None)
secret: Optional[str] = Field(
None,
exclude=True,
ACL=[
ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW),
ACL_record(verbs=[rsrc_verb.GET, rsrc_verb.DELETE, rsrc_verb.POST], target=ACL_target_group_Any(), rule=ACL_rule.DENY),
],
)
@register_rest_rootpoint
class RootApp(RestResourceBase):
login: Login = Field(
default=Login(),
plugin=ResourcePlugin_Login,
)
login: Login = Field(default=Login(), plugin=ResourcePlugin_Login)
# this add the classes to globals to allow using them later on
# => this is only for uinit-testing purpose and is not needed in real use
@@ -115,14 +119,37 @@ class Test_RestAPI_LOGIN(unittest.TestCase):
self.testapp = RootApp()
def test_login(self):
"""
result = self.testapp.process_request("/login", rsrc_verb.GET)
print(result)
print("*****************")
print(result.get_result())
result = self.testapp.process_request("/login/username", rsrc_verb.GET)
print("*****************")
print(result.get_result())
# result = self.testapp.process_request("/login/secret", rsrc_verb.GET)
# print("*****************")
# print(result.get_result())
"""
result = self.testapp.process_request("/login", rsrc_verb.PUT, '{"username":"chacha","secret":"123456"}')
print(result)
print("*****************")
print(result.get_result())
"""
result = self.testapp.process_request("/login", rsrc_verb.GET)
print(result)
print("*****************")
print(result.get_result())
result = self.testapp.process_request("/login/username", rsrc_verb.GET)
print("*****************")
print(result.get_result())
# result = self.testapp.process_request("/login/secret", rsrc_verb.GET)
# print("*****************")
# print(result.get_result())
"""
class Test_RestAPI_LOGIN_Web(unittest.TestCase):