continuing implementation of login and session

This commit is contained in:
cclecle
2023-11-05 15:38:08 +00:00
parent 346ff649ec
commit f00cf7b4b2
11 changed files with 325 additions and 113 deletions

View File

@@ -1,6 +1,7 @@
eclipse.preferences.version=1 eclipse.preferences.version=1
encoding//src/pyrestresource/__init__.py=utf-8 encoding//src/pyrestresource/__init__.py=utf-8
encoding//src/pyrestresource/__metadata__.py=utf-8 encoding//src/pyrestresource/__metadata__.py=utf-8
encoding//src/pyrestresource/rest_login.py=utf-8
encoding//src/pyrestresource/rest_resource.py=utf-8 encoding//src/pyrestresource/rest_resource.py=utf-8
encoding//src/pyrestresource/rest_resource_handler_walker.py=utf-8 encoding//src/pyrestresource/rest_resource_handler_walker.py=utf-8
encoding/<project>=UTF-8 encoding/<project>=UTF-8

View File

@@ -54,3 +54,4 @@ from .rest_resource_plugin import (
ResourcePlugin_dict_default, ResourcePlugin_dict_default,
) )
from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule
from .rest_login import RestResourceBaseLogin, UserLogin

View File

@@ -15,3 +15,17 @@ class _JSONEncoder(json.JSONEncoder):
# if the obj is uuid, we simply return the value of uuid # if the obj is uuid, we simply return the value of uuid
return str(o) return str(o)
return json.JSONEncoder.default(self, o) return json.JSONEncoder.default(self, o)
def parse_dict_cookies(cookies: str) -> dict[str, str]:
result = {}
for item in cookies.split(";"):
item = item.strip()
if not item:
continue
if "=" not in item:
result[item] = None
continue
name, value = item.split("=", 1)
result[name] = value
return result

View File

@@ -22,10 +22,6 @@ class ACL_target_group(ACL_target):
name: str name: str
class ACL_target_group_Annonymous(ACL_target):
name: str = "__ANNONYMOUS__"
class ACL_target_group_Any(ACL_target_group): class ACL_target_group_Any(ACL_target_group):
name: str = "__ANY__" name: str = "__ANY__"

View File

@@ -0,0 +1,127 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# pyrestresource(c) by chacha
#
# pyrestresource is licensed under a
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License.
#
# You should have received a copy of the license along with this
# work. If not, see <https://creativecommons.org/licenses/by-nc-sa/4.0/>.
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
"""CLI interface module"""
from __future__ import annotations
from typing import Optional, ClassVar, TYPE_CHECKING
from secrets import token_hex, compare_digest
from datetime import datetime
from pydantic import BaseModel, Field
from .rest_types import rsrc_verb
from .rest_resource import RestResourceBase
from .rest_request import RestRequest, RestRequestParams_GET
from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule
if TYPE_CHECKING or True:
from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default
class UserLogin(BaseModel):
username: str
secret: str
class UserSession(BaseModel):
last_update: datetime
user_login: UserLogin
host: Optional[str]
class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default):
ar_UserLogin: list[UserLogin] = []
def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login:
print("hook GET")
print(resource)
print(params)
return resource
def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login:
print("hook PUT")
# print(self.get_ar_userlogin())
print(resource.username)
print(resource.secret)
token = self.user_login(resource.username, resource.secret)
self.set_resp_cookie_value("Authorization", f"Bearer {token}")
return resource
class Login(RestResourceBase):
username: Optional[str] = Field(None)
secret: Optional[str] = Field(
None,
exclude=True,
ACL=[
ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW),
ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY),
],
)
class RestResourceBaseLogin(RestResourceBase):
_ar_user_login: ClassVar[list[UserLogin]] = []
_ar_user_session: dict[str, UserSession] = {}
_max_session_time_minutes: ClassVar[int] = 20
login: Login = Field(default=Login(), plugin=ResourcePlugin_Login)
def _process_request_session(self, request: RestRequest) -> None:
auth_cookie = request.get_cookie("Authorization")
if auth_cookie != None:
if auth_cookie in self._ar_user_session:
print("USER SESSION FOUND !")
print(self._ar_user_session[auth_cookie].user_login.username)
print(auth_cookie)
time_diff_min = (datetime.now() - self._ar_user_session[auth_cookie].last_update).total_seconds() / 60
if time_diff_min > self._max_session_time_minutes:
del self._ar_user_session[auth_cookie]
raise RuntimeError("session timeout ! (session reseted)")
request.set_user(self._ar_user_session[auth_cookie].user_login.username)
return
print("Invalid session")
return
print("non-connected user")
def user_login(self, user_name: str, user_secret: str, request: RestRequest) -> str:
already_failed: bool = False
for iter_user_login in self._ar_user_login:
username_ok: bool = compare_digest(user_name, iter_user_login.username)
secret_ok: bool = compare_digest(user_secret, iter_user_login.secret)
if username_ok is True:
if secret_ok is True and not already_failed:
return self._register_user_session(iter_user_login, request)
else:
already_failed = True
else:
pass
pass
if already_failed:
raise RuntimeError("Wrong auth") # TODO: specific exception
def _register_user_session(self, user_login: UserLogin, request: RestRequest) -> str:
token = token_hex(16)
new_user_session = UserSession(last_update=datetime.now(), user_login=user_login, host=request.get_host())
self._ar_user_session[f"Bearer {token}"] = new_user_session
return token

View File

@@ -3,11 +3,14 @@
from __future__ import annotations from __future__ import annotations
from typing import ( from typing import (
Any,
Optional, Optional,
Generic, Generic,
) )
from re import sub from re import sub
from urllib.parse import urlparse, parse_qs from urllib.parse import urlparse, parse_qs
from http.cookies import SimpleCookie
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typeguard import check_type from typeguard import check_type
@@ -26,7 +29,8 @@ from .rest_request_opt import (
_T_RestRequestParams_PUT, _T_RestRequestParams_PUT,
) )
from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group, ACL_target_group_Annonymous from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group
from .helpers import parse_dict_cookies
class RequestFactory( class RequestFactory(
@@ -114,6 +118,8 @@ class RestRequest(Generic[_T_RestRequestParams]):
self.url: str self.url: str
self.verb: rsrc_verb self.verb: rsrc_verb
self.data: dict self.data: dict
self.raw_headers: list[Any]
self.headers: dict[str, None | str | dict[str, None | str]] = {"host": None, "cookie": {}}
self._saved_url_params: dict self._saved_url_params: dict
self.ReqParams: _T_RestRequestParams = type_request_params() self.ReqParams: _T_RestRequestParams = type_request_params()
self.url_stack: list[str] self.url_stack: list[str]
@@ -122,7 +128,7 @@ class RestRequest(Generic[_T_RestRequestParams]):
self.incoming_cookie: dict[str, str] = incoming_cookie self.incoming_cookie: dict[str, str] = incoming_cookie
self.outgoing_cookie: dict[str, str] = outgoing_cookie self.outgoing_cookie: dict[str, str] = outgoing_cookie
self.user: ACL_target_user = ACL_target_user_Annonymous() self.user: ACL_target_user = ACL_target_user_Annonymous()
self.group: ACL_target_group = ACL_target_group_Annonymous() self.groups: list[ACL_target_group] = []
self.result: Optional[str] = None self.result: Optional[str] = None
# = or create a fresh one = # = or create a fresh one =
@@ -151,6 +157,24 @@ class RestRequest(Generic[_T_RestRequestParams]):
self._saved_url_stack = self.url_stack.copy() self._saved_url_stack = self.url_stack.copy()
self.url_stack_index = 0 self.url_stack_index = 0
def set_headers(self, headers: list[Any]) -> None:
self.raw_headers = headers
for elem in self.raw_headers:
if elem[0] == b"host":
self.headers["host"] = elem[1].decode("utf-8")
# elif elem[0] == b"user-agent":
# self.headers["user-agent"] = elem[1].decode("utf-8")
elif elem[0] == b"cookie":
self.headers["cookie"] = parse_dict_cookies(elem[1].decode("utf-8"))
def get_cookie(self, key: str) -> str | None:
if key not in self.headers["cookie"]:
return None
return self.headers["cookie"][key]
def get_host(self) -> str:
print(self.headers["host"])
def set_result(self, result: str): def set_result(self, result: str):
self.result = result self.result = result
@@ -160,8 +184,8 @@ class RestRequest(Generic[_T_RestRequestParams]):
def set_user(self, user: ACL_target_user): def set_user(self, user: ACL_target_user):
self.user: ACL_target_user = user self.user: ACL_target_user = user
def set_group(self, group: ACL_target_group): def add_group(self, group: ACL_target_group):
self.group: ACL_target_group = group self.groups.append(group)
def update_ReqParams(self, type_request_params: type[_T_RestRequestParams]): def update_ReqParams(self, type_request_params: type[_T_RestRequestParams]):
self.ReqParams = type_request_params(**self._saved_url_params) self.ReqParams = type_request_params(**self._saved_url_params)

View File

@@ -15,6 +15,7 @@ from __future__ import annotations
from abc import ABC from abc import ABC
from typing import ( from typing import (
Any,
cast, cast,
ClassVar, ClassVar,
get_args, get_args,
@@ -39,7 +40,6 @@ from .rest_ACL import (
ACL_target_user, ACL_target_user,
ACL_target_group, ACL_target_group,
ACL_target_user_Annonymous, ACL_target_user_Annonymous,
ACL_target_group_Annonymous,
ACL_target_group_Any, ACL_target_group_Any,
ACL_rule, ACL_rule,
) )
@@ -219,8 +219,8 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_
): ):
if "plugin" in self.resource.json_schema_extra: if "plugin" in self.resource.json_schema_extra:
plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"] plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"]
if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase): if not issubclass(plugin_resource, ResourcePlugin_RestResourceBase):
raise RuntimeError("Wrong plugin signature provided") raise RuntimeError(f"Wrong plugin signature provided for {plugin_resource} : {type(plugin_resource)}")
self.parent.annotation._plugins_[self.resource_name] = plugin_resource self.parent.annotation._plugins_[self.resource_name] = plugin_resource
# print("ADD RESOURCE PLUGIN") # print("ADD RESOURCE PLUGIN")
@@ -246,7 +246,7 @@ def register_rest_rootpoint(klass: type[RestResourceBase]):
class RestResourceBase(ABC, BaseModel, validate_assignment=True): class RestResourceBase(ABC, BaseModel, validate_assignment=True):
_resp_cookies: ClassVar[dict[str, str]] = dict() # _resp_cookies: ClassVar[dict[str, str]] = {}
_dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {} _dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {}
_dict_value_type_: ClassVar[dict[str, T_T_DictValues]] = {} _dict_value_type_: ClassVar[dict[str, T_T_DictValues]] = {}
_model_dump_excluded_: ClassVar[dict[str, bool]] = {} _model_dump_excluded_: ClassVar[dict[str, bool]] = {}
@@ -264,43 +264,45 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
] ]
] = {} ] = {}
def _check_acl(self, user: ACL_target_user, group: ACL_target_group, verb: rsrc_verb, field: str, is_self: bool = True): def _check_acl(self, user: str, groups: list[ACL_target_group], verb: rsrc_verb, field: str, is_self: bool = True):
# print(f"evaluate self ACLs rule: {self._ACL_record_}") print(f"evaluate self ACLs rule: {self._ACL_record_}")
print(f"user: {user}")
print(f"groups: {groups}")
if is_self and verb is rsrc_verb.GET and self.model_fields[field].exclude is True: if is_self and verb is rsrc_verb.GET and self.model_fields[field].exclude is True:
# print("ALLOWED (excluded field)") # print("ALLOWED (excluded field)")
return return
for acl in self._ACL_record_[field]: for acl in self._ACL_record_[field]:
# print(f"evaluate ACL rule: {acl}") print(f"evaluate ACL rule: {acl}")
if verb in acl.verbs: if verb in acl.verbs:
if isinstance(acl.target, ACL_target_user): if isinstance(acl.target, ACL_target_user):
if user == acl.target: if user == acl.target.name:
if acl.rule is ACL_rule.ALLOW: if acl.rule is ACL_rule.ALLOW:
# print("ALLOWED (user)") print("ALLOWED (user)")
return return
raise RuntimeError(f"Not allowed access detected: {field}") raise RuntimeError(f"Not allowed access detected: {field}")
elif isinstance(acl.target, ACL_target_group): elif isinstance(acl.target, ACL_target_group):
if group == acl.target or acl.target == ACL_target_group_Any(): if acl.target.name in groups or isinstance(acl.target, ACL_target_group_Any):
if acl.rule is ACL_rule.ALLOW: if acl.rule is ACL_rule.ALLOW:
# print("ALLOWED (group)") print("ALLOWED (group)")
return return
raise RuntimeError(f"Not allowed access detected: {field}") raise RuntimeError(f"Not allowed access detected: {field}")
else: else:
raise RuntimeError(f"Wrong ACL target type: {field}") raise RuntimeError(f"Wrong ACL target type: {field}")
# print("ALLOWED (Default)") print("ALLOWED (Default)")
def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None: def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None:
"""Check ACL on requested field access""" """Check ACL on requested field access"""
self._check_acl(request.user, request.group, request.get_verb(), request.get_resource_origin(req_index), False) self._check_acl(request.user, request.groups, request.get_verb(), request.get_resource_origin(req_index), False)
def check_acl_self(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None: def check_acl_self(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None:
"""Check ACL on requested field operation (involving checking sub-fields)""" """Check ACL on requested field operation (involving checking sub-fields)"""
if request.get_verb() is rsrc_verb.GET: if request.get_verb() is rsrc_verb.GET:
for key in self.model_fields.keys(): for key in self.model_fields.keys():
self._check_acl(request.user, request.group, rsrc_verb.GET, key) self._check_acl(request.user, request.groups, rsrc_verb.GET, key)
elif request.get_verb() is rsrc_verb.PUT: elif request.get_verb() is rsrc_verb.PUT:
for key in new_data.keys(): for key in new_data.keys():
if key in self.model_fields: if key in self.model_fields:
self._check_acl(request.user, request.group, rsrc_verb.PUT, key) self._check_acl(request.user, request.groups, rsrc_verb.PUT, key)
else: else:
raise RuntimeError("Incompatible verb") raise RuntimeError("Incompatible verb")
@@ -324,21 +326,24 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive, send):
assert scope["type"] == "http" assert scope["type"] == "http"
method = scope["method"] method = scope["method"]
assert method in ["GET", "DELETE", "PUT", "POST"] assert method in ["GET", "DELETE", "PUT", "POST"]
if b"content-type" in scope["headers"]: if b"content-type" in scope["headers"]:
assert scope["headers"][b"content-type"] == b"application/json" assert scope["headers"][b"content-type"] == b"application/json"
# import pprint import pprint
# print("----REC HEADER ---") print("----REC HEADER ---")
# pprint.pprint(scope["headers"]) pprint.pprint(scope["headers"])
body = await self.read_body(receive) body = await self.read_body(receive)
verb = rsrc_verb[scope["method"]] verb = rsrc_verb[scope["method"]]
request: RestRequest = self.process_request( request: RestRequest = self.process_request(
scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8") scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8"), scope["headers"]
) )
assert request != None assert request != None
@@ -373,12 +378,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
} }
) )
def _process_request_session(self, request: RestRequest) -> None:
pass
def process_request( def process_request(
self, self,
url: str, url: str,
verb: rsrc_verb = rsrc_verb.GET, verb: rsrc_verb = rsrc_verb.GET,
data_json: Optional[str] = None, data_json: Optional[str] = None,
query_string: Optional[str] = None, query_string: Optional[str] = None,
headers: Optional[list[Any]] = None,
) -> RestRequest: ) -> RestRequest:
from .rest_resource_handler import ( from .rest_resource_handler import (
ResourceHandler, ResourceHandler,
@@ -389,11 +398,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
if data_json: if data_json:
data = json.loads(data_json) data = json.loads(data_json)
# creating the root handler
ressource_handler: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string) ressource_handler: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string)
# preparing request & session
request: RestRequest = ressource_handler.get_request() request: RestRequest = ressource_handler.get_request()
assert request != None if headers is not None:
request.set_headers(headers)
self._process_request_session(request)
# processing the verb
result = ressource_handler.process_verb() result = ressource_handler.process_verb()
# print("OOO") # print("OOO")

View File

@@ -23,7 +23,6 @@ from .rest_ACL import (
ACL_target_user, ACL_target_user,
ACL_target_group, ACL_target_group,
ACL_target_user_Annonymous, ACL_target_user_Annonymous,
ACL_target_group_Annonymous,
ACL_target_group_Any, ACL_target_group_Any,
ACL_rule, ACL_rule,
) )
@@ -101,6 +100,7 @@ class ResourceHandler(
self.next_handler: Optional[ResourceHandler] = None self.next_handler: Optional[ResourceHandler] = None
self.saved_url: list[str] = [] self.saved_url: list[str] = []
self.resource: _T_Resource = resource self.resource: _T_Resource = resource
self.root_resource: _T_Resource = resource if prev_handler is None else prev_handler.root_resource
self.req: RestRequest self.req: RestRequest
if prev_handler is not None: if prev_handler is not None:
self.prev_handler = prev_handler self.prev_handler = prev_handler
@@ -484,14 +484,6 @@ class ResourceHandler_RestResourceBase(
if len(self.req.get_url_stack()) == 0: # destination reached if len(self.req.get_url_stack()) == 0: # destination reached
if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET: if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET:
raise RuntimeError(f"Not allowed READ access detected: {self.req.get_url_stack()}") raise RuntimeError(f"Not allowed READ access detected: {self.req.get_url_stack()}")
""" # not sure init_var has the expected behavior (read_only)
if self.resource.model_fields[self.req.get_resource_origin(0)].init_var is True and self.req.get_verb() in [
rsrc_verb.POST,
rsrc_verb.PUT,
rsrc_verb.DELETE,
]:
raise RuntimeError(f"Not allowed WRITE access detected: {self.req.get_url_stack()}")
"""
def _handle_process_get(self, params) -> RestResourceBase: def _handle_process_get(self, params) -> RestResourceBase:
# print(f"{type(self).__name__}->_process_get()") # print(f"{type(self).__name__}->_process_get()")
@@ -504,11 +496,15 @@ class ResourceHandler_RestResourceBase(
for key, attr in self.resource.model_fields.items(): for key, attr in self.resource.model_fields.items():
if key in self.resource._plugins_: if key in self.resource._plugins_:
if isinstance(self.resource._plugins_[key], ResourcePlugin_field): if isinstance(self.resource._plugins_[key], ResourcePlugin_field):
plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, self.resource._plugins_[key](self.req)) plugin_field: ResourcePlugin_field = cast(
ResourcePlugin_field, self.resource._plugins_[key](self.req, self.root_resource)
)
value = getattr(self.resource, key) value = getattr(self.resource, key)
setattr(self.resource, key, plugin_field.handle_field_get(value, params)) setattr(self.resource, key, plugin_field.handle_field_get(value, params))
elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase):
plugin_field: ResourcePlugin_field = cast(ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req)) plugin_field: ResourcePlugin_field = cast(
ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource)
)
value = getattr(self.resource, key) value = getattr(self.resource, key)
setattr(self.resource, key, plugin_field.handle_resource_get(value, params)) setattr(self.resource, key, plugin_field.handle_resource_get(value, params))
@@ -530,14 +526,14 @@ class ResourceHandler_RestResourceBase(
if isinstance(self.resource._plugins_[key], ResourcePlugin_field): if isinstance(self.resource._plugins_[key], ResourcePlugin_field):
plugin_rsrc: ResourcePlugin_RestResourceBase = cast( plugin_rsrc: ResourcePlugin_RestResourceBase = cast(
ResourcePlugin_RestResourceBase, ResourcePlugin_RestResourceBase,
self.resource._plugins_[key](self.req), self.resource._plugins_[key](self.req, self.root_resource),
) )
value = plugin_rsrc.handle_field_get(value, params) value = plugin_rsrc.handle_field_get(value, params)
elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase): elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase):
plugin_rsrc: ResourcePlugin_RestResourceBase = cast( plugin_rsrc: ResourcePlugin_RestResourceBase = cast(
ResourcePlugin_RestResourceBase, ResourcePlugin_RestResourceBase,
self.resource._plugins_[key](self.req), self.resource._plugins_[key](self.req, self.root_resource),
) )
value = plugin_rsrc.handle_resource_get(value, params) value = plugin_rsrc.handle_resource_get(value, params)
@@ -559,7 +555,9 @@ class ResourceHandler_RestResourceBase(
for key, attr in _new_resrc.model_fields.items(): for key, attr in _new_resrc.model_fields.items():
if key in _new_resrc._plugins_: if key in _new_resrc._plugins_:
if isinstance(_new_resrc._plugins_[key], ResourcePlugin_field): if isinstance(_new_resrc._plugins_[key], ResourcePlugin_field):
plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, _new_resrc._plugins_[key](self.req)) plugin_field: ResourcePlugin_field = cast(
ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.root_resource)
)
value = getattr(_new_resrc, key) value = getattr(_new_resrc, key)
setattr(_new_resrc, key, plugin_field.handle_field_put(value, params)) setattr(_new_resrc, key, plugin_field.handle_field_put(value, params))
@@ -574,7 +572,7 @@ class ResourceHandler_RestResourceBase(
if key in self.prev_handler.prev_handler.resource._plugins_: if key in self.prev_handler.prev_handler.resource._plugins_:
plugin_rsrc: ResourcePlugin_RestResourceBase = cast( plugin_rsrc: ResourcePlugin_RestResourceBase = cast(
ResourcePlugin_RestResourceBase, ResourcePlugin_RestResourceBase,
self.prev_handler.prev_handler.resource._plugins_[key](self.req), self.prev_handler.prev_handler.resource._plugins_[key](self.req, self.root_resource),
) )
_new_resrc = plugin_rsrc.handle_dict_elem_put(_new_resrc, params) _new_resrc = plugin_rsrc.handle_dict_elem_put(_new_resrc, params)
# element is within a RestResourceBase # element is within a RestResourceBase
@@ -583,7 +581,7 @@ class ResourceHandler_RestResourceBase(
if key in self.prev_handler.resource._plugins_: if key in self.prev_handler.resource._plugins_:
plugin_rsrc: ResourcePlugin_RestResourceBase = cast( plugin_rsrc: ResourcePlugin_RestResourceBase = cast(
ResourcePlugin_RestResourceBase, ResourcePlugin_RestResourceBase,
self.prev_handler.resource._plugins_[key](self.req), self.prev_handler.resource._plugins_[key](self.req, self.root_resource),
) )
_new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params) _new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params)
@@ -634,7 +632,7 @@ class ResourceHandler_simple(
if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_: if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_:
plugin_simple: ResourcePlugin_field = cast( plugin_simple: ResourcePlugin_field = cast(
ResourcePlugin_field, ResourcePlugin_field,
self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req), self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource),
) )
return plugin_simple.handle_field_get(self.resource, params) return plugin_simple.handle_field_get(self.resource, params)
@@ -655,7 +653,7 @@ class ResourceHandler_simple(
# print("PLUGIN FOUND") # print("PLUGIN FOUND")
plugin_simple: ResourcePlugin_field = cast( plugin_simple: ResourcePlugin_field = cast(
ResourcePlugin_field, ResourcePlugin_field,
self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req), self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource),
) )
# print(value) # print(value)
value = plugin_simple.handle_field_put(value, params) value = plugin_simple.handle_field_put(value, params)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, Protocol, runtime_checkable, TYPE_CHECKING from typing import Optional, Generic, TYPE_CHECKING
from abc import abstractmethod from abc import abstractmethod, ABC
from .rest_types import ( from .rest_types import (
_T_DictValues, _T_DictValues,
@@ -12,6 +12,7 @@ from .rest_types import (
from .rest_request import RestRequest from .rest_request import RestRequest
if TYPE_CHECKING or True: if TYPE_CHECKING or True:
from .rest_request_opt import ( from .rest_request_opt import (
RestRequestParams_GET, RestRequestParams_GET,
@@ -26,21 +27,33 @@ if TYPE_CHECKING or True:
) )
class ResourcePlugin(Protocol): class ResourcePlugin(ABC):
def __init__(self, request: RestRequest) -> None: def __init__(self, request: RestRequest, root_resource: "RestResourceBase") -> None:
self.request: RestRequest = request self.__request: RestRequest = request
self.__root_resource: RestRequest = root_resource
def set_resp_cookie(self, name: str, value: str): def user_login(self, user_name: str, user_secret: str) -> str:
return self.__root_resource.user_login(user_name, user_secret, self.__request)
"""
def get_ar_userlogin(self):
print("===========")
return self.__root_resource.get_ar_user_login()
"""
def getr_req_cookie_value(self, key: str) -> Optional[str]:
return self.__request.incoming_cookie[key]
def set_resp_cookie_value(self, key: str, value: str):
# print("AAA") # print("AAA")
# print(name) # print(name)
# print(value) # print(value)
# print(self.cookies) # print(self.cookies)
# print(type(self.cookies)) # print(type(self.cookies))
self.request.outgoing_cookie[name] = value self.__request.outgoing_cookie[key] = value
@runtime_checkable class ResourcePlugin_field(ResourcePlugin, Generic[TV_SupportedRESTFields]):
class ResourcePlugin_field(ResourcePlugin, Protocol[TV_SupportedRESTFields]):
@abstractmethod @abstractmethod
def handle_field_get(self, resource: TV_SupportedRESTFields, params: RestRequestParams_GET) -> TV_SupportedRESTFields: def handle_field_get(self, resource: TV_SupportedRESTFields, params: RestRequestParams_GET) -> TV_SupportedRESTFields:
... ...
@@ -60,8 +73,7 @@ class ResourcePlugin_field_default(ResourcePlugin_field[TV_SupportedRESTFields])
return resource return resource
@runtime_checkable class ResourcePlugin_RestResourceBase(ResourcePlugin, Generic[TV_RestResourceBase]):
class ResourcePlugin_RestResourceBase(ResourcePlugin, Protocol[TV_RestResourceBase]):
@abstractmethod @abstractmethod
def handle_resource_get( def handle_resource_get(
self, self,
@@ -97,8 +109,7 @@ class ResourcePlugin_RestResourceBase_default(ResourcePlugin_RestResourceBase[TV
return resource return resource
@runtime_checkable class ResourcePlugin_dict(ResourcePlugin, Generic[_T_DictKey, _T_DictValues]):
class ResourcePlugin_dict(ResourcePlugin, Protocol[_T_DictKey, _T_DictValues]):
@abstractmethod @abstractmethod
def handle_dict_get_keys( def handle_dict_get_keys(
self, self,

View File

@@ -59,7 +59,7 @@ def init_classes():
resource_with_secret_ACL: TestResource = Field( resource_with_secret_ACL: TestResource = Field(
default=TestResource(), ACL=[ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY)] default=TestResource(), ACL=[ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY)]
) )
resource2: TestResource2 = Field(TestResource2()) resource_ro: TestResource2 = Field(TestResource2())
# this add the classes to globals to allow using them later on # this add the classes to globals to allow using them later on
# => this is only for uinit-testing purpose and is not needed in real use # => this is only for uinit-testing purpose and is not needed in real use
@@ -77,21 +77,23 @@ class Test_RestAPI_ACL(unittest.TestCase):
result = self.testapp.process_request("/", rsrc_verb.GET) result = self.testapp.process_request("/", rsrc_verb.GET)
self.assertEqual(result.get_result(), "{}") self.assertEqual(result.get_result(), "{}")
result = self.testapp.process_request("/resource2", rsrc_verb.GET) result = self.testapp.process_request("/resource_ro", rsrc_verb.GET)
self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "3.2.1"}') self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "3.2.1"}')
self.testapp.process_request("/resource2/version", rsrc_verb.PUT, '"6.6.6"') self.testapp.process_request("/resource_ro/version", rsrc_verb.PUT, '"6.6.6"')
result = self.testapp.process_request("/resource2", rsrc_verb.GET) result = self.testapp.process_request("/resource_ro", rsrc_verb.GET)
self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}')
with self.assertRaises(RuntimeError): # TODO: custom exception with self.assertRaises(RuntimeError): # TODO: custom exception
self.testapp.process_request("/resource2/version_ro", rsrc_verb.PUT, '"6.6.6"') self.testapp.process_request("/resource_ro/version_ro", rsrc_verb.PUT, '"6.6.6"')
self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3")
with self.assertRaises(RuntimeError): # TODO: custom exception with self.assertRaises(RuntimeError): # TODO: custom exception
self.testapp.process_request("/resource2", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}') self.testapp.process_request("/resource_ro", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}')
self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3")
result = self.testapp.process_request("/resource2", rsrc_verb.GET) result = self.testapp.process_request("/resource_ro", rsrc_verb.GET)
self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}') self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}')
def test_subresource(self): def test_subresource(self):

View File

@@ -3,7 +3,7 @@ import unittest
from unittest.mock import patch from unittest.mock import patch
from os import chdir from os import chdir
from pathlib import Path from pathlib import Path
from typing import Optional, Annotated from typing import Optional, Annotated, ClassVar
from pydantic import Field from pydantic import Field
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from time import time, sleep from time import time, sleep
@@ -12,16 +12,17 @@ import socket
import requests import requests
from contextlib import closing from contextlib import closing
from multiprocessing import Process from multiprocessing import Process
from secrets import token_hex
print(__name__) print(__name__)
print(__package__) print(__package__)
from pydantic import BaseModel
from src.pyrestresource import ( from src.pyrestresource import (
register_rest_rootpoint, ACL_target_user,
UserLogin,
RestResourceBase, RestResourceBase,
RestResourceBaseLogin,
register_rest_rootpoint,
rsrc_verb, rsrc_verb,
RestRequestParams_GET, RestRequestParams_GET,
RestRequestParams_POST, RestRequestParams_POST,
@@ -42,58 +43,25 @@ chdir(testdir_path.parent.resolve())
# to allow mock-ing, all the tested classes are in a function # to allow mock-ing, all the tested classes are in a function
def init_classes(): def init_classes():
class UserLogin(BaseModel): user_CHACHA = UserLogin(username="chacha", secret="123456")
username: str
secret: str
token: Optional[str] = None
class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default): class TestResourceACL(RestResourceBase):
ar_UserLogin: list[UserLogin] = [UserLogin(username="chacha", secret="123456")] test_field: Optional[str] = Field(
"ORIGIN_VALUE",
def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login:
print("hook GET")
print(resource)
print(params)
return resource
def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login:
print("hook PUT")
print(resource.username)
print(resource.secret)
for _UserLogin in self.ar_UserLogin:
if _UserLogin.username == resource.username and _UserLogin.secret == resource.secret:
print("user connected")
_UserLogin.token = token_hex(16)
self.set_resp_cookie("test", _UserLogin.token)
print(f"generated token: {_UserLogin.token}")
return resource
print("login NOT found")
# print(resource)
# print(resource.username)
# print(resource.secret)
# print(params)
return resource
class Login(RestResourceBase):
username: Optional[str] = Field(None)
secret: Optional[str] = Field(
None,
exclude=True,
ACL=[ ACL=[
ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW), ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user(name="chacha"), rule=ACL_rule.ALLOW),
ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY), ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY),
], ],
) )
@register_rest_rootpoint @register_rest_rootpoint
class RootApp(RestResourceBase): class RootApp(RestResourceBaseLogin):
login: Login = Field(default=Login(), plugin=ResourcePlugin_Login) _ar_user_login: ClassVar[list[UserLogin]] = [user_CHACHA]
test_resource: TestResourceACL = TestResourceACL()
# this add the classes to globals to allow using them later on # this add the classes to globals to allow using them later on
# => this is only for uinit-testing purpose and is not needed in real use # => this is only for uinit-testing purpose and is not needed in real use
globals()[Login.__name__] = Login globals()[TestResourceACL.__name__] = TestResourceACL
globals()[RootApp.__name__] = RootApp globals()[RootApp.__name__] = RootApp
@@ -116,6 +84,61 @@ class Test_RestAPI_LOGIN(unittest.TestCase):
init_classes() init_classes()
self.testapp = RootApp() self.testapp = RootApp()
def test_access(self):
ip, port = find_free_port()
print(f"ip1={ip}")
print(f"port1={port}")
proc = Process(
target=launch_server,
args=(
ip,
port,
),
)
proc.start()
sleep(1)
s = requests.Session()
try:
# before modification read
response = s.get(
f"http://{ip}:{port}/test_resource/test_field",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), "ORIGIN_VALUE")
# try unauthenticated write
response = s.put(f"http://{ip}:{port}/test_resource/test_field", json='"TEST SET VALUE"')
self.assertEqual(response.status_code, 500)
# check not modified
response = s.get(
f"http://{ip}:{port}/test_resource/test_field",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), "ORIGIN_VALUE")
# login
response = s.put(
f"http://{ip}:{port}/login",
json={"username": "chacha", "secret": "123456"},
)
self.assertEqual(response.status_code, 201)
# authenticated write
response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE")
self.assertEqual(response.status_code, 201)
# modified
response = s.get(
f"http://{ip}:{port}/test_resource/test_field",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), "TEST SET VALUE")
finally:
proc.terminate()
s.close()
def test_login(self): def test_login(self):
result = self.testapp.process_request("/login", rsrc_verb.GET) result = self.testapp.process_request("/login", rsrc_verb.GET)
print("*****************") print("*****************")
@@ -172,6 +195,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase):
json={"username": "chacha", "secret": "123456"}, json={"username": "chacha", "secret": "123456"},
) )
print(response) print(response)
print("??????")
print(response.headers) print(response.headers)
self.assertEqual(response.status_code, 201) self.assertEqual(response.status_code, 201)