continuing implementation of login and session

This commit is contained in:
cclecle
2023-11-05 15:38:08 +00:00
parent 346ff649ec
commit f00cf7b4b2
11 changed files with 325 additions and 113 deletions

View File

@@ -1,6 +1,7 @@
eclipse.preferences.version=1
encoding//src/pyrestresource/__init__.py=utf-8
encoding//src/pyrestresource/__metadata__.py=utf-8
encoding//src/pyrestresource/rest_login.py=utf-8
encoding//src/pyrestresource/rest_resource.py=utf-8
encoding//src/pyrestresource/rest_resource_handler_walker.py=utf-8
encoding/<project>=UTF-8

View File

@@ -54,3 +54,4 @@ from .rest_resource_plugin import (
ResourcePlugin_dict_default,
)
from .rest_ACL import ACL_target_user, ACL_target_group, ACL_target_group_Any, ACL_record, ACL_rule
from .rest_login import RestResourceBaseLogin, UserLogin

View File

@@ -15,3 +15,17 @@ class _JSONEncoder(json.JSONEncoder):
# if the obj is uuid, we simply return the value of uuid
return str(o)
return json.JSONEncoder.default(self, o)
def parse_dict_cookies(cookies: str) -> dict[str, str]:
result = {}
for item in cookies.split(";"):
item = item.strip()
if not item:
continue
if "=" not in item:
result[item] = None
continue
name, value = item.split("=", 1)
result[name] = value
return result

View File

@@ -22,10 +22,6 @@ class ACL_target_group(ACL_target):
name: str
class ACL_target_group_Annonymous(ACL_target):
name: str = "__ANNONYMOUS__"
class ACL_target_group_Any(ACL_target_group):
name: str = "__ANY__"

View File

@@ -0,0 +1,127 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# pyrestresource(c) by chacha
#
# pyrestresource is licensed under a
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Unported License.
#
# You should have received a copy of the license along with this
# work. If not, see <https://creativecommons.org/licenses/by-nc-sa/4.0/>.
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
"""CLI interface module"""
from __future__ import annotations
from typing import Optional, ClassVar, TYPE_CHECKING
from secrets import token_hex, compare_digest
from datetime import datetime
from pydantic import BaseModel, Field
from .rest_types import rsrc_verb
from .rest_resource import RestResourceBase
from .rest_request import RestRequest, RestRequestParams_GET
from .rest_ACL import ACL_record, ACL_target_group_Any, ACL_rule
if TYPE_CHECKING or True:
from .rest_resource_plugin import ResourcePlugin_RestResourceBase_default
class UserLogin(BaseModel):
username: str
secret: str
class UserSession(BaseModel):
last_update: datetime
user_login: UserLogin
host: Optional[str]
class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default):
ar_UserLogin: list[UserLogin] = []
def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login:
print("hook GET")
print(resource)
print(params)
return resource
def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login:
print("hook PUT")
# print(self.get_ar_userlogin())
print(resource.username)
print(resource.secret)
token = self.user_login(resource.username, resource.secret)
self.set_resp_cookie_value("Authorization", f"Bearer {token}")
return resource
class Login(RestResourceBase):
username: Optional[str] = Field(None)
secret: Optional[str] = Field(
None,
exclude=True,
ACL=[
ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW),
ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY),
],
)
class RestResourceBaseLogin(RestResourceBase):
_ar_user_login: ClassVar[list[UserLogin]] = []
_ar_user_session: dict[str, UserSession] = {}
_max_session_time_minutes: ClassVar[int] = 20
login: Login = Field(default=Login(), plugin=ResourcePlugin_Login)
def _process_request_session(self, request: RestRequest) -> None:
auth_cookie = request.get_cookie("Authorization")
if auth_cookie != None:
if auth_cookie in self._ar_user_session:
print("USER SESSION FOUND !")
print(self._ar_user_session[auth_cookie].user_login.username)
print(auth_cookie)
time_diff_min = (datetime.now() - self._ar_user_session[auth_cookie].last_update).total_seconds() / 60
if time_diff_min > self._max_session_time_minutes:
del self._ar_user_session[auth_cookie]
raise RuntimeError("session timeout ! (session reseted)")
request.set_user(self._ar_user_session[auth_cookie].user_login.username)
return
print("Invalid session")
return
print("non-connected user")
def user_login(self, user_name: str, user_secret: str, request: RestRequest) -> str:
already_failed: bool = False
for iter_user_login in self._ar_user_login:
username_ok: bool = compare_digest(user_name, iter_user_login.username)
secret_ok: bool = compare_digest(user_secret, iter_user_login.secret)
if username_ok is True:
if secret_ok is True and not already_failed:
return self._register_user_session(iter_user_login, request)
else:
already_failed = True
else:
pass
pass
if already_failed:
raise RuntimeError("Wrong auth") # TODO: specific exception
def _register_user_session(self, user_login: UserLogin, request: RestRequest) -> str:
token = token_hex(16)
new_user_session = UserSession(last_update=datetime.now(), user_login=user_login, host=request.get_host())
self._ar_user_session[f"Bearer {token}"] = new_user_session
return token

View File

@@ -3,11 +3,14 @@
from __future__ import annotations
from typing import (
Any,
Optional,
Generic,
)
from re import sub
from urllib.parse import urlparse, parse_qs
from http.cookies import SimpleCookie
from pydantic import BaseModel, Field
from typeguard import check_type
@@ -26,7 +29,8 @@ from .rest_request_opt import (
_T_RestRequestParams_PUT,
)
from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group, ACL_target_group_Annonymous
from .rest_ACL import ACL_target_user, ACL_target_user_Annonymous, ACL_target_group
from .helpers import parse_dict_cookies
class RequestFactory(
@@ -114,6 +118,8 @@ class RestRequest(Generic[_T_RestRequestParams]):
self.url: str
self.verb: rsrc_verb
self.data: dict
self.raw_headers: list[Any]
self.headers: dict[str, None | str | dict[str, None | str]] = {"host": None, "cookie": {}}
self._saved_url_params: dict
self.ReqParams: _T_RestRequestParams = type_request_params()
self.url_stack: list[str]
@@ -122,7 +128,7 @@ class RestRequest(Generic[_T_RestRequestParams]):
self.incoming_cookie: dict[str, str] = incoming_cookie
self.outgoing_cookie: dict[str, str] = outgoing_cookie
self.user: ACL_target_user = ACL_target_user_Annonymous()
self.group: ACL_target_group = ACL_target_group_Annonymous()
self.groups: list[ACL_target_group] = []
self.result: Optional[str] = None
# = or create a fresh one =
@@ -151,6 +157,24 @@ class RestRequest(Generic[_T_RestRequestParams]):
self._saved_url_stack = self.url_stack.copy()
self.url_stack_index = 0
def set_headers(self, headers: list[Any]) -> None:
self.raw_headers = headers
for elem in self.raw_headers:
if elem[0] == b"host":
self.headers["host"] = elem[1].decode("utf-8")
# elif elem[0] == b"user-agent":
# self.headers["user-agent"] = elem[1].decode("utf-8")
elif elem[0] == b"cookie":
self.headers["cookie"] = parse_dict_cookies(elem[1].decode("utf-8"))
def get_cookie(self, key: str) -> str | None:
if key not in self.headers["cookie"]:
return None
return self.headers["cookie"][key]
def get_host(self) -> str:
print(self.headers["host"])
def set_result(self, result: str):
self.result = result
@@ -160,8 +184,8 @@ class RestRequest(Generic[_T_RestRequestParams]):
def set_user(self, user: ACL_target_user):
self.user: ACL_target_user = user
def set_group(self, group: ACL_target_group):
self.group: ACL_target_group = group
def add_group(self, group: ACL_target_group):
self.groups.append(group)
def update_ReqParams(self, type_request_params: type[_T_RestRequestParams]):
self.ReqParams = type_request_params(**self._saved_url_params)

View File

@@ -15,6 +15,7 @@ from __future__ import annotations
from abc import ABC
from typing import (
Any,
cast,
ClassVar,
get_args,
@@ -39,7 +40,6 @@ from .rest_ACL import (
ACL_target_user,
ACL_target_group,
ACL_target_user_Annonymous,
ACL_target_group_Annonymous,
ACL_target_group_Any,
ACL_rule,
)
@@ -219,8 +219,8 @@ class RestResourceWalker_Sub_RestResourceBase__tree_init(RestResourceWalker_Sub_
):
if "plugin" in self.resource.json_schema_extra:
plugin_resource: ResourcePlugin_RestResourceBase = self.resource.json_schema_extra["plugin"]
if not isinstance(plugin_resource, ResourcePlugin_RestResourceBase):
raise RuntimeError("Wrong plugin signature provided")
if not issubclass(plugin_resource, ResourcePlugin_RestResourceBase):
raise RuntimeError(f"Wrong plugin signature provided for {plugin_resource} : {type(plugin_resource)}")
self.parent.annotation._plugins_[self.resource_name] = plugin_resource
# print("ADD RESOURCE PLUGIN")
@@ -246,7 +246,7 @@ def register_rest_rootpoint(klass: type[RestResourceBase]):
class RestResourceBase(ABC, BaseModel, validate_assignment=True):
_resp_cookies: ClassVar[dict[str, str]] = dict()
# _resp_cookies: ClassVar[dict[str, str]] = {}
_dict_key_type_: ClassVar[dict[str, T_T_DictKey]] = {}
_dict_value_type_: ClassVar[dict[str, T_T_DictValues]] = {}
_model_dump_excluded_: ClassVar[dict[str, bool]] = {}
@@ -264,43 +264,45 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
]
] = {}
def _check_acl(self, user: ACL_target_user, group: ACL_target_group, verb: rsrc_verb, field: str, is_self: bool = True):
# print(f"evaluate self ACLs rule: {self._ACL_record_}")
def _check_acl(self, user: str, groups: list[ACL_target_group], verb: rsrc_verb, field: str, is_self: bool = True):
print(f"evaluate self ACLs rule: {self._ACL_record_}")
print(f"user: {user}")
print(f"groups: {groups}")
if is_self and verb is rsrc_verb.GET and self.model_fields[field].exclude is True:
# print("ALLOWED (excluded field)")
return
for acl in self._ACL_record_[field]:
# print(f"evaluate ACL rule: {acl}")
print(f"evaluate ACL rule: {acl}")
if verb in acl.verbs:
if isinstance(acl.target, ACL_target_user):
if user == acl.target:
if user == acl.target.name:
if acl.rule is ACL_rule.ALLOW:
# print("ALLOWED (user)")
print("ALLOWED (user)")
return
raise RuntimeError(f"Not allowed access detected: {field}")
elif isinstance(acl.target, ACL_target_group):
if group == acl.target or acl.target == ACL_target_group_Any():
if acl.target.name in groups or isinstance(acl.target, ACL_target_group_Any):
if acl.rule is ACL_rule.ALLOW:
# print("ALLOWED (group)")
print("ALLOWED (group)")
return
raise RuntimeError(f"Not allowed access detected: {field}")
else:
raise RuntimeError(f"Wrong ACL target type: {field}")
# print("ALLOWED (Default)")
print("ALLOWED (Default)")
def check_acl_field(self, request: RestRequest, req_index: int = 0) -> None:
"""Check ACL on requested field access"""
self._check_acl(request.user, request.group, request.get_verb(), request.get_resource_origin(req_index), False)
self._check_acl(request.user, request.groups, request.get_verb(), request.get_resource_origin(req_index), False)
def check_acl_self(self, request: RestRequest, new_data: Optional[dict[str, _T_SupportedRESTFields]]) -> None:
"""Check ACL on requested field operation (involving checking sub-fields)"""
if request.get_verb() is rsrc_verb.GET:
for key in self.model_fields.keys():
self._check_acl(request.user, request.group, rsrc_verb.GET, key)
self._check_acl(request.user, request.groups, rsrc_verb.GET, key)
elif request.get_verb() is rsrc_verb.PUT:
for key in new_data.keys():
if key in self.model_fields:
self._check_acl(request.user, request.group, rsrc_verb.PUT, key)
self._check_acl(request.user, request.groups, rsrc_verb.PUT, key)
else:
raise RuntimeError("Incompatible verb")
@@ -324,21 +326,24 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
async def __call__(self, scope, receive, send):
assert scope["type"] == "http"
method = scope["method"]
assert method in ["GET", "DELETE", "PUT", "POST"]
if b"content-type" in scope["headers"]:
assert scope["headers"][b"content-type"] == b"application/json"
# import pprint
import pprint
# print("----REC HEADER ---")
# pprint.pprint(scope["headers"])
print("----REC HEADER ---")
pprint.pprint(scope["headers"])
body = await self.read_body(receive)
verb = rsrc_verb[scope["method"]]
request: RestRequest = self.process_request(
scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8")
scope["path"], rsrc_verb[scope["method"]], body.decode("utf-8"), scope["query_string"].decode("utf-8"), scope["headers"]
)
assert request != None
@@ -373,12 +378,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
}
)
def _process_request_session(self, request: RestRequest) -> None:
pass
def process_request(
self,
url: str,
verb: rsrc_verb = rsrc_verb.GET,
data_json: Optional[str] = None,
query_string: Optional[str] = None,
headers: Optional[list[Any]] = None,
) -> RestRequest:
from .rest_resource_handler import (
ResourceHandler,
@@ -389,11 +398,16 @@ class RestResourceBase(ABC, BaseModel, validate_assignment=True):
if data_json:
data = json.loads(data_json)
# creating the root handler
ressource_handler: ResourceHandler = ResourceHandler_RestResourceBase(self, url, verb, data, query_string)
# preparing request & session
request: RestRequest = ressource_handler.get_request()
assert request != None
if headers is not None:
request.set_headers(headers)
self._process_request_session(request)
# processing the verb
result = ressource_handler.process_verb()
# print("OOO")

View File

@@ -23,7 +23,6 @@ from .rest_ACL import (
ACL_target_user,
ACL_target_group,
ACL_target_user_Annonymous,
ACL_target_group_Annonymous,
ACL_target_group_Any,
ACL_rule,
)
@@ -101,6 +100,7 @@ class ResourceHandler(
self.next_handler: Optional[ResourceHandler] = None
self.saved_url: list[str] = []
self.resource: _T_Resource = resource
self.root_resource: _T_Resource = resource if prev_handler is None else prev_handler.root_resource
self.req: RestRequest
if prev_handler is not None:
self.prev_handler = prev_handler
@@ -484,14 +484,6 @@ class ResourceHandler_RestResourceBase(
if len(self.req.get_url_stack()) == 0: # destination reached
if self.resource.model_fields[self.req.get_resource_origin(0)].exclude is True and self.req.get_verb() is rsrc_verb.GET:
raise RuntimeError(f"Not allowed READ access detected: {self.req.get_url_stack()}")
""" # not sure init_var has the expected behavior (read_only)
if self.resource.model_fields[self.req.get_resource_origin(0)].init_var is True and self.req.get_verb() in [
rsrc_verb.POST,
rsrc_verb.PUT,
rsrc_verb.DELETE,
]:
raise RuntimeError(f"Not allowed WRITE access detected: {self.req.get_url_stack()}")
"""
def _handle_process_get(self, params) -> RestResourceBase:
# print(f"{type(self).__name__}->_process_get()")
@@ -504,11 +496,15 @@ class ResourceHandler_RestResourceBase(
for key, attr in self.resource.model_fields.items():
if key in self.resource._plugins_:
if isinstance(self.resource._plugins_[key], ResourcePlugin_field):
plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, self.resource._plugins_[key](self.req))
plugin_field: ResourcePlugin_field = cast(
ResourcePlugin_field, self.resource._plugins_[key](self.req, self.root_resource)
)
value = getattr(self.resource, key)
setattr(self.resource, key, plugin_field.handle_field_get(value, params))
elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase):
plugin_field: ResourcePlugin_field = cast(ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req))
plugin_field: ResourcePlugin_field = cast(
ResourcePlugin_RestResourceBase, self.resource._plugins_[key](self.req, self.root_resource)
)
value = getattr(self.resource, key)
setattr(self.resource, key, plugin_field.handle_resource_get(value, params))
@@ -530,14 +526,14 @@ class ResourceHandler_RestResourceBase(
if isinstance(self.resource._plugins_[key], ResourcePlugin_field):
plugin_rsrc: ResourcePlugin_RestResourceBase = cast(
ResourcePlugin_RestResourceBase,
self.resource._plugins_[key](self.req),
self.resource._plugins_[key](self.req, self.root_resource),
)
value = plugin_rsrc.handle_field_get(value, params)
elif isinstance(self.resource._plugins_[key], ResourcePlugin_RestResourceBase):
plugin_rsrc: ResourcePlugin_RestResourceBase = cast(
ResourcePlugin_RestResourceBase,
self.resource._plugins_[key](self.req),
self.resource._plugins_[key](self.req, self.root_resource),
)
value = plugin_rsrc.handle_resource_get(value, params)
@@ -559,7 +555,9 @@ class ResourceHandler_RestResourceBase(
for key, attr in _new_resrc.model_fields.items():
if key in _new_resrc._plugins_:
if isinstance(_new_resrc._plugins_[key], ResourcePlugin_field):
plugin_field: ResourcePlugin_field = cast(ResourcePlugin_field, _new_resrc._plugins_[key](self.req))
plugin_field: ResourcePlugin_field = cast(
ResourcePlugin_field, _new_resrc._plugins_[key](self.req, self.root_resource)
)
value = getattr(_new_resrc, key)
setattr(_new_resrc, key, plugin_field.handle_field_put(value, params))
@@ -574,7 +572,7 @@ class ResourceHandler_RestResourceBase(
if key in self.prev_handler.prev_handler.resource._plugins_:
plugin_rsrc: ResourcePlugin_RestResourceBase = cast(
ResourcePlugin_RestResourceBase,
self.prev_handler.prev_handler.resource._plugins_[key](self.req),
self.prev_handler.prev_handler.resource._plugins_[key](self.req, self.root_resource),
)
_new_resrc = plugin_rsrc.handle_dict_elem_put(_new_resrc, params)
# element is within a RestResourceBase
@@ -583,7 +581,7 @@ class ResourceHandler_RestResourceBase(
if key in self.prev_handler.resource._plugins_:
plugin_rsrc: ResourcePlugin_RestResourceBase = cast(
ResourcePlugin_RestResourceBase,
self.prev_handler.resource._plugins_[key](self.req),
self.prev_handler.resource._plugins_[key](self.req, self.root_resource),
)
_new_resrc = plugin_rsrc.handle_resource_put(_new_resrc, params)
@@ -634,7 +632,7 @@ class ResourceHandler_simple(
if self.req.get_resource_origin(1) in self.prev_handler.resource._plugins_:
plugin_simple: ResourcePlugin_field = cast(
ResourcePlugin_field,
self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req),
self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource),
)
return plugin_simple.handle_field_get(self.resource, params)
@@ -655,7 +653,7 @@ class ResourceHandler_simple(
# print("PLUGIN FOUND")
plugin_simple: ResourcePlugin_field = cast(
ResourcePlugin_field,
self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req),
self.prev_handler.resource._plugins_[self.req.get_resource_origin(1)](self.req, self.root_resource),
)
# print(value)
value = plugin_simple.handle_field_put(value, params)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Optional, Protocol, runtime_checkable, TYPE_CHECKING
from abc import abstractmethod
from typing import Optional, Generic, TYPE_CHECKING
from abc import abstractmethod, ABC
from .rest_types import (
_T_DictValues,
@@ -12,6 +12,7 @@ from .rest_types import (
from .rest_request import RestRequest
if TYPE_CHECKING or True:
from .rest_request_opt import (
RestRequestParams_GET,
@@ -26,21 +27,33 @@ if TYPE_CHECKING or True:
)
class ResourcePlugin(Protocol):
def __init__(self, request: RestRequest) -> None:
self.request: RestRequest = request
class ResourcePlugin(ABC):
def __init__(self, request: RestRequest, root_resource: "RestResourceBase") -> None:
self.__request: RestRequest = request
self.__root_resource: RestRequest = root_resource
def set_resp_cookie(self, name: str, value: str):
def user_login(self, user_name: str, user_secret: str) -> str:
return self.__root_resource.user_login(user_name, user_secret, self.__request)
"""
def get_ar_userlogin(self):
print("===========")
return self.__root_resource.get_ar_user_login()
"""
def getr_req_cookie_value(self, key: str) -> Optional[str]:
return self.__request.incoming_cookie[key]
def set_resp_cookie_value(self, key: str, value: str):
# print("AAA")
# print(name)
# print(value)
# print(self.cookies)
# print(type(self.cookies))
self.request.outgoing_cookie[name] = value
self.__request.outgoing_cookie[key] = value
@runtime_checkable
class ResourcePlugin_field(ResourcePlugin, Protocol[TV_SupportedRESTFields]):
class ResourcePlugin_field(ResourcePlugin, Generic[TV_SupportedRESTFields]):
@abstractmethod
def handle_field_get(self, resource: TV_SupportedRESTFields, params: RestRequestParams_GET) -> TV_SupportedRESTFields:
...
@@ -60,8 +73,7 @@ class ResourcePlugin_field_default(ResourcePlugin_field[TV_SupportedRESTFields])
return resource
@runtime_checkable
class ResourcePlugin_RestResourceBase(ResourcePlugin, Protocol[TV_RestResourceBase]):
class ResourcePlugin_RestResourceBase(ResourcePlugin, Generic[TV_RestResourceBase]):
@abstractmethod
def handle_resource_get(
self,
@@ -97,8 +109,7 @@ class ResourcePlugin_RestResourceBase_default(ResourcePlugin_RestResourceBase[TV
return resource
@runtime_checkable
class ResourcePlugin_dict(ResourcePlugin, Protocol[_T_DictKey, _T_DictValues]):
class ResourcePlugin_dict(ResourcePlugin, Generic[_T_DictKey, _T_DictValues]):
@abstractmethod
def handle_dict_get_keys(
self,

View File

@@ -59,7 +59,7 @@ def init_classes():
resource_with_secret_ACL: TestResource = Field(
default=TestResource(), ACL=[ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY)]
)
resource2: TestResource2 = Field(TestResource2())
resource_ro: TestResource2 = Field(TestResource2())
# this add the classes to globals to allow using them later on
# => this is only for uinit-testing purpose and is not needed in real use
@@ -77,21 +77,23 @@ class Test_RestAPI_ACL(unittest.TestCase):
result = self.testapp.process_request("/", rsrc_verb.GET)
self.assertEqual(result.get_result(), "{}")
result = self.testapp.process_request("/resource2", rsrc_verb.GET)
result = self.testapp.process_request("/resource_ro", rsrc_verb.GET)
self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "3.2.1"}')
self.testapp.process_request("/resource2/version", rsrc_verb.PUT, '"6.6.6"')
self.testapp.process_request("/resource_ro/version", rsrc_verb.PUT, '"6.6.6"')
result = self.testapp.process_request("/resource2", rsrc_verb.GET)
result = self.testapp.process_request("/resource_ro", rsrc_verb.GET)
self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}')
with self.assertRaises(RuntimeError): # TODO: custom exception
self.testapp.process_request("/resource2/version_ro", rsrc_verb.PUT, '"6.6.6"')
self.testapp.process_request("/resource_ro/version_ro", rsrc_verb.PUT, '"6.6.6"')
self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3")
with self.assertRaises(RuntimeError): # TODO: custom exception
self.testapp.process_request("/resource2", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}')
self.testapp.process_request("/resource_ro", rsrc_verb.PUT, '{"version_ro": "6.6.1", "version": "6.6.2"}')
self.assertEqual(self.testapp.resource_ro.version_ro, "1.2.3")
result = self.testapp.process_request("/resource2", rsrc_verb.GET)
result = self.testapp.process_request("/resource_ro", rsrc_verb.GET)
self.assertEqual(result.get_result(), '{"version_ro": "1.2.3", "version": "6.6.6"}')
def test_subresource(self):

View File

@@ -3,7 +3,7 @@ import unittest
from unittest.mock import patch
from os import chdir
from pathlib import Path
from typing import Optional, Annotated
from typing import Optional, Annotated, ClassVar
from pydantic import Field
from uuid import UUID, uuid4
from time import time, sleep
@@ -12,16 +12,17 @@ import socket
import requests
from contextlib import closing
from multiprocessing import Process
from secrets import token_hex
print(__name__)
print(__package__)
from pydantic import BaseModel
from src.pyrestresource import (
register_rest_rootpoint,
ACL_target_user,
UserLogin,
RestResourceBase,
RestResourceBaseLogin,
register_rest_rootpoint,
rsrc_verb,
RestRequestParams_GET,
RestRequestParams_POST,
@@ -42,58 +43,25 @@ chdir(testdir_path.parent.resolve())
# to allow mock-ing, all the tested classes are in a function
def init_classes():
class UserLogin(BaseModel):
username: str
secret: str
token: Optional[str] = None
user_CHACHA = UserLogin(username="chacha", secret="123456")
class ResourcePlugin_Login(ResourcePlugin_RestResourceBase_default):
ar_UserLogin: list[UserLogin] = [UserLogin(username="chacha", secret="123456")]
def handle_resource_get(self, resource: Login, params: RestRequestParams_GET) -> Login:
print("hook GET")
print(resource)
print(params)
return resource
def handle_resource_put(self, resource: Login, params: RestRequestParams_GET) -> Login:
print("hook PUT")
print(resource.username)
print(resource.secret)
for _UserLogin in self.ar_UserLogin:
if _UserLogin.username == resource.username and _UserLogin.secret == resource.secret:
print("user connected")
_UserLogin.token = token_hex(16)
self.set_resp_cookie("test", _UserLogin.token)
print(f"generated token: {_UserLogin.token}")
return resource
print("login NOT found")
# print(resource)
# print(resource.username)
# print(resource.secret)
# print(params)
return resource
class Login(RestResourceBase):
username: Optional[str] = Field(None)
secret: Optional[str] = Field(
None,
exclude=True,
class TestResourceACL(RestResourceBase):
test_field: Optional[str] = Field(
"ORIGIN_VALUE",
ACL=[
ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.ALLOW),
ACL_record(verbs=[rsrc_verb.GET], target=ACL_target_group_Any(), rule=ACL_rule.DENY),
ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_user(name="chacha"), rule=ACL_rule.ALLOW),
ACL_record(verbs=[rsrc_verb.PUT], target=ACL_target_group_Any(), rule=ACL_rule.DENY),
],
)
@register_rest_rootpoint
class RootApp(RestResourceBase):
login: Login = Field(default=Login(), plugin=ResourcePlugin_Login)
class RootApp(RestResourceBaseLogin):
_ar_user_login: ClassVar[list[UserLogin]] = [user_CHACHA]
test_resource: TestResourceACL = TestResourceACL()
# this add the classes to globals to allow using them later on
# => this is only for uinit-testing purpose and is not needed in real use
globals()[Login.__name__] = Login
globals()[TestResourceACL.__name__] = TestResourceACL
globals()[RootApp.__name__] = RootApp
@@ -116,6 +84,61 @@ class Test_RestAPI_LOGIN(unittest.TestCase):
init_classes()
self.testapp = RootApp()
def test_access(self):
ip, port = find_free_port()
print(f"ip1={ip}")
print(f"port1={port}")
proc = Process(
target=launch_server,
args=(
ip,
port,
),
)
proc.start()
sleep(1)
s = requests.Session()
try:
# before modification read
response = s.get(
f"http://{ip}:{port}/test_resource/test_field",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), "ORIGIN_VALUE")
# try unauthenticated write
response = s.put(f"http://{ip}:{port}/test_resource/test_field", json='"TEST SET VALUE"')
self.assertEqual(response.status_code, 500)
# check not modified
response = s.get(
f"http://{ip}:{port}/test_resource/test_field",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), "ORIGIN_VALUE")
# login
response = s.put(
f"http://{ip}:{port}/login",
json={"username": "chacha", "secret": "123456"},
)
self.assertEqual(response.status_code, 201)
# authenticated write
response = s.put(f"http://{ip}:{port}/test_resource/test_field", json="TEST SET VALUE")
self.assertEqual(response.status_code, 201)
# modified
response = s.get(
f"http://{ip}:{port}/test_resource/test_field",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), "TEST SET VALUE")
finally:
proc.terminate()
s.close()
def test_login(self):
result = self.testapp.process_request("/login", rsrc_verb.GET)
print("*****************")
@@ -172,6 +195,7 @@ class Test_RestAPI_LOGIN_Web(unittest.TestCase):
json={"username": "chacha", "secret": "123456"},
)
print(response)
print("??????")
print(response.headers)
self.assertEqual(response.status_code, 201)